在完成了MNIST手写数字模型的训练之后,我们就可以使用训练好的模型进行预测手写数字了。这里还是使用MNIST数据集中所提供的测试数据。

可以仅仅对测试集的数据进行预测,并直接打印出来结果即可。但是为了和原图像进行对比,这里定义了一个可视化的函数,将原图像以及预测结果值进行显示,可以使结果更加直观。

在上述基础上加上下面代码就可以了。

# 对测试集的数据进行预测
prediction_result = sess.run(tf.argmax(pred, 1), feed_dict={x: mnist.test.images})


# 定义可视化函数
def plot_images_labels_prediction(images,  # 图像列表
                                  labels,  # 标签列表
                                  prediction,  # 预测值列表
                                  index,   # 从index个开始显示
                                  num=10):  # 缺省一次显示10幅
      fig = plt.gcf()   # 获取当前图表
      fig.set_size_inches(10, 12)  # 显示成英寸(1英寸等于2.54cm)
      if num > 25:
            num = 25   # 最多显示25幅图片
      for i in range(0, num):
            ax = plt.subplot(5, 5, i+1)  # 画多个子图(5*5)

            ax.imshow(np.reshape(images[index], (28, 28)), cmap='binary')  # 显示第index张图像

            title = "label=" + str(np.argmax(labels[index]))   # 构建图片上要显示的title
            if len(prediction) > 0:
                  title += ", predict=" + str(prediction[index])

            ax.set_title(title, fontsize=10)
            ax.set_xticks([])  # 不显示坐标轴
            ax.set_yticks([])
            index += 1
      plt.show()

# 从第11张照片开始显示,显示25张
plot_images_labels_prediction(mnist.test.images, mnist.test.labels, prediction_result, 10, 25)

可视化结果:
在这里插入图片描述
从上面的预测结果我们可以看出,只有个别图片预测错误,大部分的预测数值都是正确的。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐