摘要
本文记录tensorflow的学习入门过程,主要是MNIST在tensorflow中完成的整个过程进行笔记的记录。
读取数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
| import tensorflow as tf import os
from tensorflow.examples.tutorials.mnist import input_data import scipy.misc import matplotlib.pyplot as plt import matplotlib.image as mpimg
import numpy as np
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print(mnist.__dir__())
save_dir = 'MNIST_data/raw/' if os.path.exists(save_dir) is False: os.makedirs(save_dir)
for i in range(20): image_arry = mnist.train.images[i, :] image_arry = image_arry.reshape(28, 28) filename = save_dir + 'mnist_train_%d.jpg' % i scipy.misc.toimage(image_arry, cmin=0.0, cmax=1.0).save(filename)
fig = plt.figure() plotwindow = fig.add_subplot(111) plt.axis('off') for i in range(10): one_hot_label = mnist.train.labels[i, :] label = np.argmax(one_hot_label) print('mnist_train_%d.jpg label:%d' % (i, label)) file = mpimg.imread('MNIST_data/raw/mnist_train_%d.jpg' % i) plt.imshow(file, cmap='gray') plt.title(u'image-%i' % label, loc='left') plt.show() plt.clf() plt.close()
|
一般国内上google是上不了的,所以如果你先前没在MNIST_data/ 文件路径下放好这四个压缩包,一般会提示网络连接超时。此时自己去百度下载好这四个训练样本。
结果出来想问下这个数字到底是几啊,我没看出来,但是标签里写的是7