본문으로 바로가기

지난번에 단층 신경망만 가지고 MNIST의 필기 숫자를 학습해서 90%의 accuracy가 나오는 것을 경험[바로가기]했는데요. 사실 MNIST로 테스트하시는 분들께서는 겨우 90%라고 하시겠지만 겨우 단층 신경망으로 해본거니까요... 이번에는 신경망의 층 수를 조금 늘리고, 가중치의 초기값을 구하는 것에 옵션을 하나 추가해 봅니다. 언제나그렇듯~~~ 이 글은 여러 유명한 고수님의 글을 따라한 거지요... 우와~~ 나도 해보니까 되는데요^^ 입니다.^^

언제나 그랬지만, 오늘은 특별히 글 앞 부분의 코드는 김성훈 교수님의 유명한 딥러닝과 텐서플로우 공개 강좌[바로가기]의 내용을 따르고 있습니다. 특히 오늘은 가중치의 초기값을 잡아주는 Xavier 초기화 방법을 설명해 주시는 자료[바로가기]를 참조했습니다.단, 글 후반부 숫자 데이터를 확인하는 과정은 알량한 지식으로 제가 살을 정말 쪼금 붙였습니다.

일단... 초기값을 잘 선정하는 것이 중요하다... 그렇지 않으면 레이어를 늘려도 큰 효과가 없더라 정도만 이해하고 전 코드의 결과를 보고 싶으니 그냥 진행했습니다. 언젠가 이렇게 기초를 잘 학습하지 않고 지나간 것에 후회하는 날이 오겠지만... 뭐 일단 급해요....

import tensorflow as tf
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from tensorflow.examples.tutorials.mnist import input_data
tf.set_random_seed(777)  # reproducibility

%matplotlib inline

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

사용하던 모듈을 import 하구요... MNIST 데이터도 받구요...

# parameters
training_epochs = 15
batch_size = 100

# input place holders
X = tf.placeholder(tf.float32, [None, 784])
Y = tf.placeholder(tf.float32, [None, 10])

# weights & bias for nn layers
W1 = tf.get_variable("W1", shape=[784, 256], 
                     initializer=tf.contrib.layers.xavier_initializer())
b1 = tf.Variable(tf.random_normal([256]))
L1 = tf.nn.relu(tf.matmul(X, W1) + b1)

W2 = tf.get_variable("W2", shape=[256, 256], 
                     initializer=tf.contrib.layers.xavier_initializer())
b2 = tf.Variable(tf.random_normal([256]))
L2 = tf.nn.relu(tf.matmul(L1, W2) + b2)

W3 = tf.get_variable("W3", shape=[256, 256], 
                     initializer=tf.contrib.layers.xavier_initializer())
b3 = tf.Variable(tf.random_normal([256]))
L3 = tf.nn.relu(tf.matmul(L2, W3) + b3)

W4 = tf.get_variable("W4", shape=[256, 256], 
                     initializer=tf.contrib.layers.xavier_initializer())
b4 = tf.Variable(tf.random_normal([256]))
L4 = tf.nn.relu(tf.matmul(L3, W4) + b4)

W5 = tf.get_variable("W5", shape=[256, 10], 
                     initializer=tf.contrib.layers.xavier_initializer())
b5 = tf.Variable(tf.random_normal([10]))

hypothesis = tf.matmul(L4, W5) + b5

이렇게 지난번 글[바로가기]과 동일하게 epoch와 batch_size는 잡고, 무려 5개 Layer로 가중치를 잡고... initializer에 xavier를 사용한다고 잡아주면 됩니다. 음.. 이게 끝입니다.^^

# define cost/loss & optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=hypothesis, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(cost)

# Test model and check accuracy
prediction = tf.equal(tf.argmax(hypothesis, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))

그리고 나서 softmax에 cross entropy로 loss를 잡고, Adam Optimizer를 사용하고, learning rate은 0.001로 했습니다.

# Launch graph
sess = tf.Session()
sess.run(tf.global_variables_initializer())
    
# train my model
for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples / batch_size)

    pbar = tqdm(range(total_batch))
    for i in pbar:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        c, _ = sess.run([cost, optimizer], feed_dict={X: batch_xs, Y: batch_ys})
        avg_cost += c / total_batch
        pbar.set_description("cost : %f" % avg_cost)
        
# Accuracy report
print('Accuracy:', sess.run(accuracy, feed_dict={X: mnist.test.images, Y: mnist.test.labels}))

이렇게 하고 학습을 해주면...

헉.. Accuracy가 97.5%가 나오네요.. 흠.. 이것만으로 저렇게 향상되다니..ㅠㅠ. 

그러면.. 못 맞춘 경우가 어떤지 한 번 비교해보죠~

index = []
ori = []
pred = []

labels = sess.run(tf.argmax(mnist.test.labels, 1))
predictions = sess.run(tf.argmax(hypothesis, 1), feed_dict={X: mnist.test.images})

for i in range(0,mnist.test.num_examples):
    if labels[i] != predictions[i]:
        index.append(i)
        ori.append(labels[i])
        pred.append(predictions[i])
        
fail_result = pd.DataFrame({'label':ori, 'predict':pred}, index=index)

plt.figure(figsize=(12,6))
plt.hist(fail_result['predict'], bins=10)
plt.xlabel('fault_prediction')
plt.grid()
plt.show()

이렇게 해서 결과를 보죠. 이전에 단층 신경망[바로가기]으로 테스트했을 때는

이랬구요...

지금은 이렇습니다. 전체적으로 실패한 예측이 현저히 줄면서 분포값이 내려가 있네요...

plt.figure(figsize=(12,6))
sns.violinplot(x="label", y="predict", data=fail_result)
plt.xlabel('fault_prediction')
plt.grid()
plt.show()

바이올린 플랏으로 보니

4는 8이나 9로 오해(^^)를 많이 했네요^^ 특히 0은 6으로 많은 오류를 가졌구요...

plt.figure(figsize=(12,6))
sns.swarmplot(x="label", y="predict", data=fail_result)
sns.despine(offset=10, trim=True)
plt.xlabel('fault_prediction')
plt.grid()
plt.show()

그리고...

단층신경망만으로 학습했을때 오류 분포였구요 [바로가기] ... 오늘의 성과를 보면

이렇습니다. 확실히 줄어든 것이 보이죠...음.. 4에 대한 판단은 좀 확인을 해봐야 겠네요...

verifying_data = fail_result.query('label == 4').sample(n=8).index

count = 0
nrows = ncols = 4
plt.figure(figsize=(12,8))
for n in verifying_data:
    count += 1
    plt.subplot(nrows, ncols, count)
    plt.imshow(mnist.test.images[n].reshape(28, 28), cmap='Greys', interpolation='nearest')
    tmp = "Label:" + str(fail_result['label'][n]) + ", Prediction:" + str(fail_result['predict'][n])
    plt.title(tmp)

plt.tight_layout()
plt.show()

를 이용해서 4에 대해 좀 확인을 해볼께요...

네.. 9로 오해(^^)한 경우가 확실히 많에요^^


댓글을 달아 주세요

  1. BlogIcon 귀요미디지 2018.04.09 10:47 신고

    오늘 낮부터 날이 포근해진다 하네요~
    즐거운 한주의 시작
    월요일 되세요 ^^

  2. BlogIcon 드래곤포토 2018.04.09 12:21 신고

    즐거운 한주 보내세요 ^^

  3. BlogIcon 휴식같은 친구 2018.04.09 17:34 신고

    잘 보고 갑니다. 너무 어려워 보입니다.ㅎ
    즐거운 하루 되세요.

  4. BlogIcon 비키니짐(VKNY GYM) 2018.04.09 19:49 신고

    전문적인 내용이라 어렵네요~ ㅎㅎ
    오늘 하루 잘 마무리하세요~

  5. BlogIcon Deborah 2018.04.09 19:56 신고

    아주 심오하지만 지극히 이분야에 몸담으신 분이라면 용어 이해력이 뛰어날듯 싶네요.

  6. BlogIcon peterjun 2018.04.10 00:49 신고

    어렵지만 멋지네요.
    숫자 판독에 대한 정확도가 올라가면 더더욱 멋지겠어요. ㅋ

  7. BlogIcon IT넘버원 2018.04.10 02:22 신고

    오호 어렵지만 이런식으로 판독할 수 있군요.

  8. BlogIcon 공수래공수거 2018.04.10 06:43 신고

    이런일이 흥미로운게 저는 신기합니다 ㅎ
    기분좋은 하루 되세요^^

  9. BlogIcon 핑구야 날자 2018.04.10 07:00 신고

    호기심이 가는 포스툉이네요 잘 보고 갑니다

  10. BlogIcon 즐거운 우리집 2018.04.10 07:33 신고

    필기체는 암호 같아요 ㅎㅎㅎ

  11. BlogIcon 멜로요우 2018.04.10 08:38 신고

    뭔가 매번 볼때마다 대단한거같아요~ 공학쪽은 거의 잼병이라서 모르겠어요 ㅍ퓨 흑....

  12. BlogIcon 스티마 2018.04.10 10:03 신고

    와우 저도 어서 공부해서, 진도 따가 가고 싶네요.

  13. BlogIcon Bliss :) 2018.04.11 03:55 신고

    저희 딸의 글씨를 판독하며 기싸움하는 저의 모습 같아요~ㅎㅎㅎ 저는 딸에게 고치라고 했는데, PinkWink님은 오류를 줄여 정확도를 높이는 방향으로 스스로 해결하셨군요ㅎㅎ 어렵지만 조심히 보고 갑니다!!ㅎ 활기찬 하루 되세요^^