pbj0812의 코딩 일기

[통계학] python을 이용한 최소제곱법과 경사하강법 구현 본문

Science/통계학

[통계학] python을 이용한 최소제곱법과 경사하강법 구현

pbj0812 2020. 10. 4. 02:57

0. 목표

 - python을 이용한 최소제곱법과 경사하강법 구현(1차식 한정)

1. 실습

 1) library 호출

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

 2) 데이터 생성

  - 대략 2x + 1 의 느낌으로 생성

df = pd.DataFrame({"x" : [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "y" : [3.1, 5.2, 7.1, 9.2, 11.2, 13.1, 15.2, 17.1, 18.9, 20.9]})

 3) 그림

plt.plot(df['x'], df['y'])
plt.axis([0, max(df['x']), 0, max(df['y'])])

 4) 최소제곱법 테스트용 함수 구현

  - 3x + 1

# inp 는 수
def test(inp):
    result = (3 * inp) + 1
    return result

 5) 최소제곱법

# inp는 데이터 프레임(x, y 순)
def LSM(inp):
    inp2 = np.array(inp)
    tmp_x = inp2[:, 0]
    tmp_y = inp2[:, 1]
    len_inp = len(tmp_x)
    result = []
    for i in range(len_inp):
        tmp = (tmp_y[i] - test(tmp_x[i])) ** 2
        result.append(tmp)
    result = sum(result) / 2
    return result

  - 테스트

print(LSM(df))

  - 결과

0.10999999999999999

 6) 경사하강법

  - 편미분(수기)을 통한 theta0 + theta1 * x의 theta0와 theta1 갱신식을 구현

  - 초기의 식은 3x + 1로 주고 갱신하는 형태

# inp는 데이터 프레임(x, y 순)
# num : 자연수
# lr : 양의 실수
# 식 : theta0 + theta1 * x
def GD(inp, num, lr):
    inp2 = np.array(inp)
    tmp_x = inp2[:, 0]
    tmp_y = inp2[:, 1]
    len_inp = len(tmp_x)
    theta0 = 1
    theta1 = 3
    result = []
    for i in range(num):
        tmp = 0
        tmp1 = 0
        LSM = 0
        for j in range(len_inp):
            tmp += (theta0 + (theta1 * tmp_x[j])) - tmp_y[j]
            tmp1 += ((theta0 + (theta1 * tmp_x[j])) - tmp_y[j]) * tmp_x[j]
            LSM += (tmp_y[j] - (theta0 + (theta1 * tmp_x[j]))) ** 2
        print(i + 1, "차 시도")
        print('theta0 : ', theta0)
        print('theta1 : ', theta1)
        print("LSM : ", LSM / 2)
        print('==================')
        theta0 = theta0 - (lr * tmp)
        theta1 = theta1 - (lr * tmp1)

  - 테스트

GD(df, 20, 0.001)

  - 결과

   * 주어진 인풋 데이터가 이상할 시 숫자가 밖으로 튀어나감

   * lr을 0.01만 주어도 값이 이상해짐

1 차 시도
theta0 :  1
theta1 :  3
LSM :  189.11000000000004
==================
2 차 시도
theta0 :  0.946
theta1 :  2.6185
LSM :  69.81626812500008
==================
3 차 시도
theta0 :  0.9135224999999999
theta1 :  2.3868475
LSM :  25.84777002189062
==================
4 차 시도
theta0 :  0.8941106624999999
theta1 :  2.246167475
LSM :  9.641904794148136
==================
5 차 시도
theta0 :  0.8826303447499999
theta1 :  2.1607169106875
LSM :  3.6685184215280033
==================
6 차 시도
theta0 :  0.8759646112146874
theta1 :  2.1087962311115627
LSM :  1.4665197444495024
==================
7 차 시도
theta0 :  0.8722211723914045
theta1 :  2.0772316285168033
LSM :  0.6545435450313584
==================
8 차 시도
theta0 :  0.8702512210990663
theta1 :  2.058025287056307
LSM :  0.3548897276426328
==================
9 차 시도
theta0 :  0.8693573180999787
theta1 :  2.04632173437918
LSM :  0.24406424682728295
==================
10 차 시도
theta0 :  0.869116049528124
theta1 :  2.039173214147697
LSM :  0.20283685332463877
==================
11 차 시도
theta0 :  0.8692703622547194
theta1 :  2.0347901439767866
LSM :  0.1872628318967262
==================
12 차 시도
theta0 :  0.869664200713449
theta1 :  2.032086068621714
LSM :  0.18114546642377874
==================
13 차 시도
theta0 :  0.8702028249321202
theta1 :  2.0304014011631146
LSM :  0.17851512446674014
==================
14 차 시도
theta0 :  0.8708287196188277
theta1 :  2.0293357063440487
LSM :  0.177171571180053
==================
15 차 시도
theta0 :  0.8715069685737167
theta1 :  2.0286458798225544
LSM :  0.1763038591084904
==================
16 차 시도
theta0 :  0.872216375497739
theta1 :  2.0281843328193165
LSM :  0.17561309065913633
==================
17 차 시도
theta0 :  0.8729440734376992
theta1 :  2.027861464031504
LSM :  0.17498909455120656
==================
18 차 시도
theta0 :  0.8736822521815895
theta1 :  2.0276228763403013
LSM :  0.17439125841287934
==================
19 차 시도
theta0 :  0.874426171461057
theta1 :  2.027435545079298
LSM :  0.1738046072358808
==================
20 차 시도
theta0 :  0.875172954767085
theta1 :  2.02727942079341
LSM :  0.17322361520561808
==================

2. 참고

 - 기초 수학으로 이해하는 머신러닝 알고리즘(타테이시 켄고)

Comments