일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
- 독후감
- 한빛미디어서평단
- MySQL
- Linux
- MATLAB
- 서평
- tensorflow
- 매틀랩
- matplotlib
- python visualization
- 파이썬 시각화
- 리눅스
- 시각화
- 텐서플로
- 통계학
- Ga
- 블로그
- SQL
- Tistory
- 서평단
- Visualization
- 월간결산
- 티스토리
- Python
- 딥러닝
- 파이썬
- Pandas
- Google Analytics
- Blog
- 한빛미디어
- Today
- Total
pbj0812의 코딩 일기
[Python] sklearn의 DecisionTree 사용 / Graphviz 설치 본문
[Python] sklearn의 DecisionTree 사용 / Graphviz 설치
pbj0812 2020. 3. 11. 01:440. 목표
- sklearn 의 DecisionTree를 이용한 Titanic 문제 해결
1. DecisionTreeClassifier(참고)
1) criterion : 분류 기준(default = 'gini')
2) max_depth : decision tree의 깊이 지정
3) min_samples_split : 최소 샘플 개수
4) min_samples_leaf : 최소 분류 수
5) max_features : 최대 피쳐 수
2. 데이터셋 준비
- kaggle 타이타닉 데이터 셋 다운로드(링크에서 titanic 검색)
3. 코드 작성
1) 데이터 선정
import pandas as pd
data = pd.read_csv('E:/수료증/인프런/밑바닥부터시작하는머신러닝/train.csv')
data2 = data[['Pclass', 'Sex', 'Survived']]
- Pclass : 승선권 클래스(1, 2, 3)
- Sex : 성별(male, female)
- Survived : 생존여부(1 : 생존, 0 : 사망) => Y 값
2) 데이터 전처리
(1) 원 핫 인코딩을 통해 문자열을 숫자로 변경
data2['Sex'] = data2['Sex'].replace({"male":0, "female":1})
data2.head()
(2) 결측치(null 의 유무) 확인
data2.isnull().sum()
(3) y 값 추출
y = data2['Survived'].values
(4) x 값 추출
del data2['Survived']
data2.head()
x = data2.values
(5) 정규화
from sklearn.preprocessing import MinMaxScaler
minmax_scaler = MinMaxScaler()
minmax_scaler.fit(x)
x = minmax_scaler.transform(x)
3) 모델 구성
- LogisticRegression과 DecisionTree를 동시에 사용
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.pipeline import Pipeline
from sklearn.pipeline import make_pipeline
algorithmes = [LogisticRegression(), DecisionTreeClassifier()]
c_params = [0.1, 5, 7, 10, 15, 20, 100]
params = []
params.append([{
"solver" : ['saga'],
"penalty" : ["l1"],
"C" : c_params
}, {
"solver" : ['liblinear'],
"penalty" : ["l2"],
"C" : c_params
}])
params.append({
"criterion" : ["gini", "entropy"],
"max_depth" : [10, 8, 7, 6, 5, 4, 3, 2],
"min_samples_leaf" : [1, 2, 3, 4, 5, 6, 7, 8, 9]
})
4) 모델 실행
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report, accuracy_score
scoring = ['accuracy']
estimator_results = []
for i, (estimator, params) in enumerate(zip(algorithmes, params)):
gs_estimator = GridSearchCV(refit="accuracy", estimator=estimator, param_grid=params, scoring=scoring, cv=5,verbose=1, n_jobs=4)
print(gs_estimator)
gs_estimator.fit(x, y)
estimator_results.append(gs_estimator)
5) 결과 확인
(1) 점수 비교
- LogisticRegression의 결과 : 0.7867564534231201
- DecisionTreeClassification의 결과 : 0.7732884399551067
estimator_results[0].best_score_
estimator_results[1].best_score_
(2) DecisionTreeClassification 결과에서 feature 중요도 확인
- 결과 : array([0.24987426, 0.75012574]) => Pclass, Sex 순
estimator_results[1].best_estimator_.feature_importances_
6) 그래프 그리기
- 이때, graph.write_png("titanic.png") 에서 에러가 발생하는 경우가 있을 때는 아래 참고의 경로 설정이 필요
* import는 가능하나, 에러 발생의 경우에 해당
import pydotplus
from sklearn.externals.six import StringIO
from sklearn import tree
best_tree = estimator_results[1].best_estimator_
column_names = data2.columns
dot_data = StringIO()
tree.export_graphviz(best_tree, out_file=dot_data, feature_names = column_names)
graph = pydotplus.pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png("titanic.png")
from IPython.core.display import Image
Image(filename='titanic.png')
4. 참고 문헌
5. 기타(graphviz 경로설정)
1) 링크 에서 zip 다운로드
2) 환경변수 설정1
- 제어판 -> 시스템 및 보안 -> 시스템 -> 고급 시스템 설정 -> 환경 변수 -> 시스템 변수 -> 새로 만들기 에서 1)에서 압축을 풀면 생성되는 폴더 내에서 bin -> dot.exe의 경로를 입력
3) 환경변수 설정2
- 제어판 -> 시스템 및 보안 -> 시스템 -> 고급 시스템 설정 -> 환경 변수 -> 시스템 변수 -> Path -> 편집 -> 새로만들기에서 1)에서 압축을 풀면 생성되는 폴더 내에서 bin의 경로를 입력
4) 확인
'ComputerLanguage_Program > PYTHON' 카테고리의 다른 글
[PYTHON] 네트워크 분석을 위한 networkx 예제 (0) | 2020.04.09 |
---|---|
[PYTHON] TensorFlow Certification 시험준비(PyCharm) (0) | 2020.04.04 |
[PYTHON] PyQt5 + pyinstaller를 사용한 twitter 크롤링 프로그램 제작 (0) | 2020.02.24 |
[PYTHON] OpenCV를 활용한 기생충 웹캠 어플리케이션 만들기 (0) | 2020.02.22 |
[PYTHON] OpenCV를 활용한 나만의 기생충 포스터 만들기 (0) | 2020.02.15 |