pbj0812의 코딩 일기

[TensorFlow] 01 tutorial_Basic_Classification 본문

인공지능 & 머신러닝/TensorFlow

[TensorFlow] 01 tutorial_Basic_Classification

pbj0812 2019. 7. 11. 23:47

이 글은 TensorFlow tutorial(링크)의 예제를 재구성한 글입니다.

TensorFlow의 keras를 이용하여 fashion mnist 데이터를 학습하고 예측하는 예제입니다.

학습에 필요한 부분만 정리하였으니 전체코드는 위 링크를 통해 보시기 바랍니다.

 

라이브러리 불러오기

import tensorflow as tf
from tensorflow import keras

import numpy as np
import matplotlib.pyplot as plt

- tensorflow의 keras 사용

데이터 불러오기

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

- fashion_mnist 데이터 사용

- load_data()를 이용하여 데이터를 다운 받으면서 각각의 변수(train_images...)에 데이터를 입력한다.

- 학습 데이터 60,000개 + 테스트 데이터 10,000개로 구성되어 있다.

- 레이블은 0~9로 총 10종류이다.

- 데이터는 28*28의 형태로 이루어져있다.

레이블 매핑

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

- label은 0~9로 이루어져 있기에 추후 매핑을 위해 미리 입력해준다.

데이터 확인

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()

- matplotlib 을 이용하여 그린 그림

- colorbar를 보았을 때 0~255의 수로 이루어져 있는 것을 확인할 수 있다.

- plt.grid(True)로 할 경우 아래 그림처럼 나타난다.

데이터 전처리

train_images = train_images / 255.0
test_images = test_images / 255.0

- 데이터가 0~255로 이루어져있기에 255로 나누어주어 정규화를 시킨다.

- 이 작업을 해주지 않을 경우 학습이 제대로 이루어지지 않는다.

모델 구성

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation = tf.nn.relu),
    keras.layers.Dense(10, activation = tf.nn.softmax)
])

- 입력값의 형태는 28 * 28의 형태로 들어가게 지정했다.

- hidden layer는 1층으로 이루어져있고 128개의 뉴런으로 구성되었다.

- 활성화함수는 ReLU 함수를 사용한다.

- 출력층에서는 softmax함수를 사용하였으며 10개의 값으로 출력된다.(전체 합은 1이다.)

모델 컴파일

model.compile(optimizer = 'adam',
             loss = 'sparse_categorical_crossentropy',
             metrics = ['accuracy'])

- sparse_categorical_crossentropy 는 정수형의 타겟(0, 1, 2 ...)을 목표로 할 때 사용한다.

- metrics=['accuracy'] 는 평가를 하기 위해 사용된다.

모델 학습

model.fit(train_images, train_labels, epochs = 5)

- 5번 학습

모델 평가

test_loss, test_acc = model.evaluate(test_images, test_labels)

- 정확도 87%

- 확실히 mnist 데이터로 훈련시킬때(98% 이상)보다 정확도가 낮은것을 알 수 있다.

Comments