[파이썬 머신러닝 완벽 가이드] Scikit-learn 개요 및 붓꽃 분류 실습
Scikit-learn 소개
— 머신러닝을 위한 다양한 알고리즘과 개발을 위한 프레임워크 및 API 제공
- 모델 별 유기성이 잘 이루어져 있어 생산성을 높여주고 실전 환경에서도 이미 검증되어 많이 사용된다.
- 주로 numpy와 scipy 기반으로 구축
주요 API/모듈
Estimator 클래스
— 핵심 메서드
- 학습: fit()
- 예측: predict()
- Estimator 클래스에 포함된 모든 클래스가 fit()과 predict()를 통해 학습과 예측결과를 반환한다.
— Classifier 클래스 : 분류 알고리즘 구현 클래스
- DecisionTreeClassifier
- RandomForestClassifier
- GradientBoostingClassifier
- GaussianNB
- SVC
— Regressor 클래스 : 회귀 알고리즘 구현 클래스
- LinearRegression
- Ridge
- Lasso
- RandomForestRegressor
- GradientBoostingRegressor
주요 모듈
- 예제 데이터 : sklearn.datasets
- 데이터 분리, 검증 & 파라미터 튜닝 : sklearn.model_selection
- 교차검증을 위한 학습용/테스트용 데이터 분리, 그리드서치로 최적 파라미터 추출 등 API 제공
- 피처 처리
- sklearn.preprocessing : 데이터 전처리용 기능 제공
- sklearn.feature_selection : 알고리즘에 큰 영향을 미치는 피처를 우선순위대로 셀렉션 작업해주는 기능 제공
- sklearn.feature_extraction : 텍스트나 이미지 데이터의 벡터화된 피처 추출에 사용
- 피처 처리 & 차원 축소 : sklearn.decompostion
- 평가 : sklearn.metrics
- ML 알고리즘 (아주 다양하지만 그중 대표적인 몇가지만)
- sklearn.ensemble
- sklearn.linear_model
- sklearn.naive_bayes
- sklearn.neighbors
- sklearn.svm
- sklearn.tree
- sklearn.cluster
- 유틸리티 : sklearn.pipeline
- 피처 처리 등의 변환과 ML 알고리즘 학습, 예측 등을 함께 묶어 실행할 수 있는 유틸리티 제공
사이킷런을 이요한 붓꽃 데이터 분류
— 일명 머신러닝의 “Hello World”
- Sepal length & width, Petal length & width 네가지 feature를 사용해서 붓꽃의 품종을 분류 (Classification) 하는 지도학습의 기본적인 예제
- 피처(Feature) : 데이터 세트의 일반 속성. 타겟값을 제외한 나머지 속성을 모두 피처로 지칭
- 레이블 = 클래스 = 타겟(값) = 결정(값)
- 지도학습 시 데이터의 학습을 위해 주어지는 정답 데이터 == 타겟값 또는 결정값
- 지도학습 중 분류의 경우 결정값을 레이블 또는 클래스로 지칭
- 결국 다 같은 것을 의미한다
분류
— 지도학습(Supervised Learning)의 대표적인 방법
- 다향한 피처와 분류 결정값인 레이블 데이터로 모델을 학습한 뒤 별도의 테스트 데이터셋에서 미지의 레이블을 예측
- 즉, 지도학습은 명확한 정답이 주어진 데이터를 먼저 학습한 뒤 미지의 정답을 예측
- 주어지는 데이터는 학습 데이터와 테스트 데이터
붓꽃 데이터 분류 프로세스
- 데이터셋 분리 : 데이터를 학습 데이터와 테스트 데이터로 분리
- 모델 학습 : 학습 데이터를 기반으로 ML 알고리즘을 적용해서 모델을 학습
- 예측 수행 : 학습된 모델을 사용해서 테스트 데이터를 예측
- 평가 : 예측된 결과값과 테스트 데이터의 실제 결과값을 비교해서 모델 성능을 평가
학습과 테스트 데이터셋의 분리
train_test_split() : sklearn.model_selection 클래스 메서드
━ 주요 파라미터
- test_size : 전체 데이터에서 테스트 데이터셋 크기를 얼마로 샘플링할 것인지 결정. 디폴트는 0.25 (25%)
- train_size : 학습 데이터셋 크기 결정. test_size를 통상적으로 사용하기 때문에 잘 사용되지 않음
- shuffle : 데이터를 분리하기 전 미리 섞을 것인지 결정. 디폴트는 True
- random_state : 동일한 데이터셋 생성을 위해 부여하는 난수 값
실습
# 사이킷런 버전 확인
import sklearn
print(sklearn.__version__)
[OUT] 1.0.1
붓꽃 예측을 위한 사이킷런 필요 모듈 로딩
# 내장 붓꽃 데이터셋
from sklearn.datasets import load_iris
# 의사결정트리 모델
from sklearn.tree import DecisionTreeClassifier
# 데이터셋 분리 함수
from sklearn.model_selection import train_test_split
import pandas as pd
붓꽃 데이터셋 로딩
(사이킷런 데이터셋 종류에 관한 정보는 다음 참고: https://scikit-learn.org/stable/datasets/toy_dataset.html#iris-dataset)
# 붓꽃 데이터셋 로딩
iris = load_iris()
# 딕셔너리 형태
print(type(iris))
iris
[OUT] <class 'sklearn.utils.Bunch'>
{'data': array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.6, 1.4, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]]),
'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
'frame': None,
'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
'DESCR': '.. _iris_dataset:\\\\n\\\\nIris plants dataset\\\\n--------------------\\\\n\\\\n**Data Set Characteristics:**\\\\n\\\\n :Number of Instances: 150 (50 in each of three classes)\\\\n :Number of Attributes: 4 numeric, predictive attributes and the class\\\\n :Attribute Information:\\\\n - sepal length in cm\\\\n - sepal width in cm\\\\n - petal length in cm\\\\n - petal width in cm\\\\n - class:\\\\n - Iris-Setosa\\\\n - Iris-Versicolour\\\\n - Iris-Virginica\\\\n \\\\n :Summary Statistics:\\\\n\\\\n ============== ==== ==== ======= ===== ====================\\\\n Min Max Mean SD Class Correlation\\\\n ============== ==== ==== ======= ===== ====================\\\\n sepal length: 4.3 7.9 5.84 0.83 0.7826\\\\n sepal width: 2.0 4.4 3.05 0.43 -0.4194\\\\n petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\\\\n petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\\\\n ============== ==== ==== ======= ===== ====================\\\\n\\\\n :Missing Attribute Values: None\\\\n :Class Distribution: 33.3% for each of 3 classes.\\\\n :Creator: R.A. Fisher\\\\n :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\\\\n :Date: July, 1988\\\\n\\\\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\\\\nfrom Fisher\\\\'s paper. Note that it\\\\'s the same as in R, but not as in the UCI\\\\nMachine Learning Repository, which has two wrong data points.\\\\n\\\\nThis is perhaps the best known database to be found in the\\\\npattern recognition literature. Fisher\\\\'s paper is a classic in the field and\\\\nis referenced frequently to this day. (See Duda & Hart, for example.) The\\\\ndata set contains 3 classes of 50 instances each, where each class refers to a\\\\ntype of iris plant. One class is linearly separable from the other 2; the\\\\nlatter are NOT linearly separable from each other.\\\\n\\\\n.. topic:: References\\\\n\\\\n - Fisher, R.A. "The use of multiple measurements in taxonomic problems"\\\\n Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\\\\n Mathematical Statistics" (John Wiley, NY, 1950).\\\\n - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\\\\n (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\\\\n - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\\\\n Structure and Classification Rule for Recognition in Partially Exposed\\\\n Environments". IEEE Transactions on Pattern Analysis and Machine\\\\n Intelligence, Vol. PAMI-2, No. 1, 67-71.\\\\n - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions\\\\n on Information Theory, May 1972, 431-433.\\\\n - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II\\\\n conceptual clustering system finds 3 classes in the data.\\\\n - Many, many more ...',
'feature_names': ['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)'],
'filename': 'iris.csv',
'data_module': 'sklearn.datasets.data'}
# data: 4개의 피처(feature)들로 이루어진 numpy 배열
iris_data = iris.data
# target: 붓꽃 데이터셋의 레이블(결정 값) 데이터로 된 numpy배열
iris_label = iris.target
print('iris target값:', iris_label)
print('iris target명:', iris.target_names)
# 붓꽃 데이터셋을 자세히 보기 위해 DataFrame으로 변환
iris_df = pd.DataFrame(data=iris_data, columns=iris.feature_names)
# 레이블 컬럼 추가
iris_df['label'] = iris_label
iris_df.head(3)
[OUT] iris target값: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
iris target명: ['setosa' 'versicolor' 'virginica']
sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | label | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
학습 데이터와 테스트 데이터 세트로 분리
# 데이터셋의 20%를 테스트 데이터로 사용
X_train, X_test, y_train, y_test = train_test_split(iris_data, iris_label,
test_size=0.2, random_state=11)
학습 데이터 세트로 학습(Train) 수행
# DecisionTreeClassifier 객체 생성
dt_clf = DecisionTreeClassifier(random_state=11)
# 학습 수행
dt_clf.fit(X_train, y_train)
테스트 데이터 세트로 예측(Predict) 수행
# 학습이 완료된 DecisionTreeClassifier 객체에서 테스트 데이터 세트로 예측 수행.
pred = dt_clf.predict(X_test)
pred
[OUT] array([2, 2, 1, 1, 2, 0, 1, 0, 0, 1, 1, 1, 1, 2, 2, 0, 2, 1, 2, 2, 1, 0,
0, 1, 0, 0, 2, 1, 0, 1])
예측 정확도 평가
from sklearn.metrics import accuracy_score
# y_test: 테스트 데이터셋의 타겟값, pred: 모델이 예측한 값
print('예측 정확도: {0:.4f}'.format(accuracy_score(y_test,pred)))
[OUT] 예측 정확도: 0.9333
댓글남기기