pbj0812의 코딩 일기

[통계학] CART 구현을 통한 TITANIC 변수 선택 본문

Science/통계학

[통계학] CART 구현을 통한 TITANIC 변수 선택

pbj0812 2020. 3. 10. 02:41

0. 목표

 - CART 알고리즘을 통해 우선적으로 분류되어야 할 변수를 선택

1. 이론

 1) Gini Index를 통해 데이터의 대상 속성을 얼마나 잘못 분류할지를 계산

 2) 각 속성별(male, female, 1, 2, 3...)로 계산하여 최소값을 계산

2. 데이터셋 준비

 1) kaggle 타이타닉 데이터 셋 다운로드(링크에서 titanic 검색)

 2) 데이터 전처리

import pandas as pd

data = pd.read_csv('E:/수료증/인프런/밑바닥부터시작하는머신러닝/train.csv')

data2 = data[['Pclass', 'Sex', 'Survived']]

  - Pclass : 승선권 클래스(1, 2, 3)

  - Sex : 성별(male, female)

  - Survived : 생존여부(1 : 생존, 0 : 사망) => Y 값

3. 시나리오

 1) Sex

  (1) male : 1 - Gini(male) - Gini(others)

  (2) female : 1 - Gini(female) - Gini(others)

 2) Pclass

  (1) 1 : 1 - Gini(1) - Gini(others)

  (2) 2 : 1 - Gini(2) - Gini(others)

  (3) 3 : 1 - Gini(3) - Gini(others)

4. 구현 (2. 2)에 이어서)

  - 함수의 입력값으로는 데이터셋, 변수, y값

  - selection 에는 변수별 하위 분류값(male or female, 1 or 2 or 3)들을 저장하고, selection2 에는 정답의 분류값(0, 1)을 저장

 - 저장 공간을 두 개로 나누어서 한 곳에는 해당 속성이 들어가고 다른 한 곳에는 나머지 속성이 들어가도록 설계

def CART(dataset, attribute, y):
    selection = list(set(dataset[attribute]))
    selection2 = list(set(dataset[y]))
    dataset_len = len(dataset)
    # 분류
    for i in selection:
        result = []
        result2 = []
        data1 = dataset[dataset[attribute] == i]
        data2 = dataset[dataset[attribute] != i]
        data1_len = len(data1)
        data2_len = len(data2)
        # 생존 분류
        for j in selection2:
            data1_2 = data1[data1[y] == j]
            data2_2 = data2[data2[y] == j]
            data1_2_len = len(data1_2)
            data2_2_len = len(data2_2)
            result.append((data1_2_len/data1_len)**2)
            result2.append((data2_2_len/data2_len)**2)
        Gini1 = 1 - sum(result)
        Gini2 = 1 - sum(result2)
        final = (data1_len/dataset_len)*Gini1 + (data2_len/dataset_len)*Gini2
        print(i, ' : ', final)

5. 결과

 1) Sex

  - male : 0.3333650003885904

  - female : 0.3333650003885904

 2) Pclass

  - 1 : 0.4343484224965707

  - 2 : 0.4688911437727624

  - 3 : 0.4238751054331502

CART(data2, 'Sex', 'Survived')
CART(data2, 'Pclass', 'Survived')

 => 가장 낮은 값인 male or female을 기준으로 나누는게 Pclass의 결과보다 좋은 결과를 얻을 가능성이 높음

6. 참고

 - 밑바닥부터 시작하는 머신러닝(최성철 교수님)

Comments