3.2.1 tensorflow之MNIST

摘要

本文记录tensorflow的学习入门过程,主要是MNIST在tensorflow中完成的整个过程进行笔记的记录。

读取数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# coding: utf-8
import tensorflow as tf
import os
# 在不使用keras的情况下
from tensorflow.examples.tutorials.mnist import input_data
import scipy.misc
import matplotlib.pyplot as plt
import matplotlib.image as mpimg # mpimg 用于读取图片

import numpy as np
# 从MNIST_data/中读取数据,如果不存在就会自动下载
# 这个input_data在mnist文件夹下
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# print(mnist.train.images.shape)
# print(mnist.train.labels.shape)
# print(mnist.validation.images.shape)
# print(mnist.validation.labels.shape)
# print(mnist.test.images.shape)
# print(mnist.test.labels.shape)
# 查看
print(mnist.__dir__())
# print(dir(mnist))

# 把原始图片存在这个路径下
save_dir = 'MNIST_data/raw/'
if os.path.exists(save_dir) is False:
os.makedirs(save_dir)

# 保存图片
for i in range(20):
# 请注意,mnist.train.images[i, :]就表示第i张图片
image_arry = mnist.train.images[i, :]
image_arry = image_arry.reshape(28, 28)
# 保存文件的格式为:
# mnist_train_0.jpg, mnist_train_1.jpg, ..., mnist_train_19.jpg
filename = save_dir + 'mnist_train_%d.jpg' % i
# 将iamge_array 保存为图片
scipy.misc.toimage(image_arry, cmin=0.0, cmax=1.0).save(filename)

# 看前10张图片的样子
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
for i in range(10):
# 得到的都是one-hot 表示
one_hot_label = mnist.train.labels[i, :]
label = np.argmax(one_hot_label)
print('mnist_train_%d.jpg label:%d' % (i, label))
file = mpimg.imread('MNIST_data/raw/mnist_train_%d.jpg' % i)
plt.imshow(file, cmap='gray')
plt.title(u'image-%i' % label, loc='left')
plt.show()
plt.clf()
plt.close()

一般国内上google是上不了的,所以如果你先前没在MNIST_data/ 文件路径下放好这四个压缩包,一般会提示网络连接超时。此时自己去百度下载好这四个训练样本。

结果出来想问下这个数字到底是几啊,我没看出来,但是标签里写的是7

这个图像出来是数字7吗?

文章目录
  1. 1. 摘要
    1. 1.1. 读取数据
|