본문으로 바로가기

데이터 사이언스 카테고리에서 저는 [Tensor Flow]라고 제목에 머릿말을 달고 연재를 시작했는데요. 처음에는 나이, 체중에 따른 혈중지방함량치를 선형회귀로 예측하는 예제[바로가기] 수행했었는데요. 그때 단층 신경망을 사용했었죠. 오늘은 MNIST 필기 숫자를 판독해볼려고 합니다.

언제나 그랬지만, 오늘은 특별히 글 앞 부분의 코드는 김성훈 교수님의 유명한 딥러닝과 텐서플로우 공개 강좌[바로가기]의 내용을 따르고 있습니다. 단, 글 후반부 숫자 데이터를 확인하는 과정은 알량한 지식으로 제가 살을 정말 쪼금 붙였습니다.

오늘의 결과는 좋지 않습니다. 그러나 이미지 인식률을 높이는 것이 목적이 아니라, 간단한 신경망으로 예제 하나를 해보는 것이 그 내용입니다. 먼저...

import tensorflow as tf
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from tensorflow.examples.tutorials.mnist import input_data
tf.set_random_seed(777)  # reproducibility

%matplotlib inline

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

필요한 모듈을 import 합니다. 이 글에서는 tensorflow를 1.5를 사용하고 있습니다. 그리고, 텐서플로우가 제공하는 mnist를 받을 수 있도록하고 있고, MNIST 데이터를 읽을때, 아예 one_hot으로 읽었습니다. MNIST의 필기 숫자는 어떻게 생겼을까요???

count = 0
nrows = ncols = 4
plt.figure(figsize=(12,8))
for n in range(0,8):
    count += 1
    plt.subplot(nrows, ncols, count)
    plt.imshow(mnist.test.images[n].reshape(28, 28), cmap='Greys', interpolation='nearest')

plt.tight_layout()
plt.show()

를 실행하면 나타나는

위 그림으로 확인할 수 있습니다. train 데이터 55000개, test 데이터 10000개로 되어 있는 MNIST 데이터의 제일 앞부분 8개만 확인한 것인데요. 필기 숫자들이 맞네요^^ (당연하지만)

# parameters
training_epochs = 15
batch_size = 100

# input place holders
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

# weights & bias for nn layers
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.random_normal([10]))

# hypothesis
hypothesis = tf.matmul(X, W) + b

# define cost/loss & optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=hypothesis, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

# Test model and check accuracy
prediction = tf.equal(tf.argmax(hypothesis, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))

앞 선 예제[바로가기]에서와 거의 같은 구조입니다. 단, 28*28 크기의 이미지를 784*1로 펼쳤기 때문에, X, W의 크기를 신경써야합니다. 그리고, 전체 학습 횟수를 의미하는 epochs 횟수는 15회, 빠른 학습 시간을 위해 한번에 처리하는 숫자 그림의 갯수를 의미하는 batch_size는 100으로 설정되어 있습니다. 활성함수 activation functino으로는 소프트맥스 softmax를 사용합니다. 숫자를 고르는 거니 당연한 선택이겠지만, 코스트는 활성함수로 softmax를 선택하고, cross_entropy를 적용한 후, 최종 결과로 숫자를 얻도록하는 softmax_cross_entropy_with_logits라는 함수를 하나 사용합니다. 이제

# Launch graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())
    
# train my model
for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples / batch_size)

    pbar = tqdm(range(total_batch))
    for i in pbar:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        c, _ = sess.run([cost, optimizer], feed_dict={X: batch_xs, Y: batch_ys})
        avg_cost += c / total_batch
        pbar.set_description("cost : %f" % avg_cost)
        
# Accuracy report
print('Accuracy:', sess.run(accuracy, feed_dict={X: mnist.test.images, Y: mnist.test.labels}))

학습을 하도록 하죠~ 이때, Cost를 계속 확인하기 위해 보통 print로 cost를 찍어보는데, 저는 [바로가기]에서 소개한 적이 있는 tqdm을 사용했습니다. tqdm을 사용하면 꽤 괜찮은 화면을 얻을 수 있거든요~^^

저렇게 cost의 변화를 각 epoch마다 쉽게 알 수 있습니다.

아무튼.. 최종 겨로가는 90.23%의 accuracy를 보여주네요... 흠... 뭐 학교때 90점 이상 받아본 경험이 잘 없는 저는 이 성적이 좋지만, 사실 MNIST 데이터는 뭐 99% 이상은 다들 훌쩍 넘으니 좋은 성과는 아니겠지요. 처음에도 이야기했지만, 오늘은 결과보다는 어떻게 사용할 것인가 입니다.^^ 

아무튼... test 데이터 10000개 중에서 accuracy가 90.23%가 나왔으면, 977개 쯤의 못 맞춘 데이터가 있겠죠... 그것도 확인을 좀 더 해보도록 하겠습니다.

index = []
ori = []
pred = []

labels = sess.run(tf.argmax(mnist.test.labels, 1))
predictions = sess.run(tf.argmax(hypothesis, 1), feed_dict={X: mnist.test.images})

for i in range(0,mnist.test.num_examples):
    if labels[i] != predictions[i]:
        index.append(i)
        ori.append(labels[i])
        pred.append(predictions[i])
        
fail_result = pd.DataFrame({'label':ori, 'predict':pred}, index=index)
fail_result.head()

위 코드는 틀린 숫자들의 정보만 모아둡니다.

test 데이터의 몇 번째(index) 숫자(label)를 뭐로 잘 못 예측(predict)했는지를 따로 기록해 두었습니다. 제일 첫 줄을 읽으면, mnist.test.image 7번째 숫자는 9인데, 5로 잘 못 해석했다는 것입니다. 간단히

plt.figure(figsize=(12,6))
plt.hist(fail_result['predict'], bins=10)
plt.xlabel('fault_prediction')
plt.grid()
plt.show()

히스토그램을 그려보겠습니다.

잘 못 된 예측에는 5로 잘 못 인식한 아이가 160여개 이상으로 가장 많았습니다. 아직 뭐 더 봐야겠지만, 어떤 숫자를 5로 잘 못 읽은 경우가 많다는 거겠죠.

plt.figure(figsize=(12,6))
sns.violinplot(x="label", y="predict", data=fail_result)
plt.xlabel('fault_prediction')
plt.grid()
plt.show()

좀 더 자세히 seaborn의 violinplot[바로가기]으로 그려보면

각 숫자가 어떻게 잘 못 해석되었는지를 조금 알 수 있습니다. 틀린 것 중에서, 0의 경우는 5로 가장 많이 오판을 했고, 7의 경우 9와 2로 가장 많이 오판했습니다. 좀 더...

plt.figure(figsize=(12,6))
sns.swarmplot(x="label", y="predict", data=fail_result)
sns.despine(offset=10, trim=True)
plt.xlabel('fault_prediction')
plt.grid()
plt.show()

swarmplot으로 자세히 보면

8의 경우는 오판된 범위가 골고루 이지만, 5의 경우는 8, 6, 4, 3으로 많이 오판한 것 같습니다. 그럼

verifying_data = fail_result.query('label == 5').sample(n=8).index

count = 0
nrows = ncols = 4
plt.figure(figsize=(12,8))
for n in verifying_data:
    count += 1
    plt.subplot(nrows, ncols, count)
    plt.imshow(mnist.test.images[n].reshape(28, 28), cmap='Greys', interpolation='nearest')
    tmp = "Label:" + str(fail_result['label'][n]) + ", Prediction:" + str(fail_result['predict'][n])
    plt.title(tmp)

plt.tight_layout()
plt.show()

위 코드로 5를 오판한 그림을 실제로 한 번 보겠습니다. 여기서는 pandas에서 label이 5인 숫자들을 8개만 랜덤하게 샘플을 얻어서 어떻게 생겼길래 틀린건지 한 번 보겠습니다.

이렇습니다. 흠. 위에서 세 번째는 5인데, 6이라고 틀릴 만도 합니다. 뭐 첫번째도 3이라고 틀릴 수 있을 것도 같습니다. 밑에 줄의 두 번째도 8로 볼 수도 있을 것 같습니다.^^. 사실 처음에도 이야기 드렸지만, 90%의 accuracy가 mnist의 경우 좋은게 절대 아닙니다. 그저 앞으로 몇 번 더 mnist 데이터를 볼 건데 어떻게 향상되어 가는지를 볼려고 하는 것과, 지금은 단층 신경망으로 그저 절차를 확인하는 것 정도로 생각하시면 됩니다.

n=1
plt.imshow(mnist.test.images[n:n+1].reshape(28, 28), cmap='Greys', interpolation='nearest');

추가로... 위 코드는 

이 이미지인데... 내가 만약 내 글씨를 저렇게 파일로 읽을 수 있으면,

test_img = mnist.test.images[n:n+1]
print(sess.run(tf.argmax(hypothesis, 1), feed_dict={X: test_img}))

저렇게 넣어 볼 수 있을 겁니다. 뭐 결과는 2라고 나오지만요^^


댓글을 달아 주세요

  1. BlogIcon luvholic 2018.03.02 09:10 신고

    숫자모양 결과값이 나오는게 신기하네요^^

  2. BlogIcon 즐거운 우리집 2018.03.02 10:26 신고

    날로 인식률은 향상되겠죠 ㅎ

  3. BlogIcon Bliss :) 2018.03.02 14:54 신고

    ㅎㄷㄷㄷ 단층신경망으로 필기 숫자를 읽다니!!! 저희가 사용하는 모든 프로그램의
    결과물을 위해 입력 및 실행하는 코드는 정말 어마어마한 것 같아요>.< 천재&황금손 같아 보여요!!!

  4. BlogIcon 귀요미디지 2018.03.02 16:41 신고

    몇일 포근하더니 오늘 다시 쌀쌀하네요
    감기조심하시고
    즐거운 금요일 되세요 ^^

  5. BlogIcon 휴식같은 친구 2018.03.02 16:54 신고

    신기한 코딩의 세계네요.
    잘 보고 갑니다.

  6. BlogIcon peterjun 2018.03.03 00:37 신고

    복잡하네요. 90%가 낮은 수치라니... ㅎㅎ
    주말 즐겁게 잘 보내세요. ^^

  7. BlogIcon 공수래공수거 2018.03.03 06:26 신고

    배울수록 재미있어지겠는데요? ㅎ
    즐거운 주말 되시기 바랍니다

  8. BlogIcon 핑구야 날자 2018.03.03 08:06 신고

    스마트폰으로 사진을 찍으면 번역이 되는 시스템도 가능하겠군요