原文链接:https://geektutu.com/post/tensorflow-mnist-simplest.html
本文介绍了机器学习中的 "hello world" —— MNIST 手写数字识别 😗
原文采用:
Python 3.6、TensorFlow 1.4
本文采用:(截至2020年10月最高支持的 TensorFlow 版本,部分代码略作修改)
Python 3.8.5(64位)、TensorFlow CPU 2.3.0
先上结果:
loss:
accuracy:
神经网络:
bias and weight distribution:
Histograms:
代码:
model.py
import tensorflow as tf
from tensorflow.python.framework.ops import disable_eager_execution

# The model is built with the graph-mode (tf.compat.v1) API, so eager
# execution is switched off before any op is created.
disable_eager_execution()


class Network:
    """Single fully-connected softmax layer mapping 784 inputs to 10 classes.

    Constructing an instance adds all ops to the default graph and exposes
    the interesting tensors/ops as attributes:

    * ``x``           -- float32 placeholder of shape ``[None, 784]``
    * ``label``       -- float32 placeholder of shape ``[None, 10]``
    * ``w`` / ``b``   -- zero-initialised weight ``[784, 10]`` and bias ``[10]``
    * ``y``           -- softmax output of ``x @ w + b``
    * ``loss``        -- summed cross-entropy between ``label`` and ``y``
    * ``train``       -- gradient-descent step that also bumps ``global_step``
    * ``accuracy``    -- mean of argmax(label) == argmax(y)
    * ``global_step`` -- non-trainable counter of applied updates
    """

    def __init__(self):
        self.learning_rate = 0.001
        self.global_step = tf.Variable(0, trainable=False, name="global_step")

        # Graph inputs.
        self.x = tf.compat.v1.placeholder(tf.float32, [None, 784], name="x")
        self.label = tf.compat.v1.placeholder(tf.float32, [None, 10], name="label")

        # Parameters of the single dense layer.
        self.w = tf.Variable(tf.zeros([784, 10]), name="fc/weight")
        self.b = tf.Variable(tf.zeros([10]), name="fc/bias")

        # Forward pass.
        logits = tf.matmul(self.x, self.w) + self.b
        self.y = tf.nn.softmax(logits, name="y")

        # The 1e-10 offset keeps log() away from log(0).
        label_weighted_log_prob = self.label * tf.compat.v1.log(self.y + 1e-10)
        self.loss = -tf.reduce_sum(label_weighted_log_prob)

        optimizer = tf.compat.v1.train.GradientDescentOptimizer(self.learning_rate)
        self.train = optimizer.minimize(self.loss, global_step=self.global_step)

        is_correct = tf.equal(tf.argmax(self.label, 1), tf.argmax(self.y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(is_correct, "float"))

        # Summary nodes: histograms for w and b, scalar curves for loss and
        # accuracy.
        for tag, tensor in (("weight", self.w), ("bias", self.b)):
            tf.compat.v1.summary.histogram(tag, tensor)
        for tag, tensor in (("loss", self.loss), ("accuracy", self.accuracy)):
            tf.compat.v1.summary.scalar(tag, tensor)

# Next listing: train.py
import tensorflow as tf
# NOTE(review): `tensorflow.examples.tutorials` is, as far as I know, not
# shipped with stock TensorFlow 2.x wheels -- confirm it is available in the
# target environment (or vendor `input_data.py` next to this script).
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework.ops import disable_eager_execution

from model import Network

disable_eager_execution()

# Directory that holds the training checkpoints.
CKPT_DIR = 'ckpt'


class Train:
    """Train the softmax `Network` on MNIST and evaluate it on the test split."""

    def __init__(self):
        self.net = Network()
        self.sess = tf.compat.v1.Session()
        self.sess.run(tf.compat.v1.global_variables_initializer())
        # Data is read from (and cached in) ./data_set with one-hot labels.
        self.data = input_data.read_data_sets('./data_set', one_hot=True)

    def train(self):
        """Run mini-batch training until `global_step` reaches 20000.

        Resumes from the latest checkpoint in ``CKPT_DIR`` when one exists,
        writes a merged summary to ``./log`` every 100 steps and saves a
        checkpoint every 1000 steps (keeping the 5 most recent).
        """
        batch_size = 64
        train_step = 20000
        save_interval = 1000
        step = 0

        saver = tf.compat.v1.train.Saver(max_to_keep=5)
        # Merge every summary node registered by the network into one op.
        merged_summary_op = tf.compat.v1.summary.merge_all()
        # Visualisation data goes to ./log under the current directory.
        merged_writer = tf.compat.v1.summary.FileWriter("./log", self.sess.graph)

        try:
            ckpt = tf.compat.v1.train.get_checkpoint_state(CKPT_DIR)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(self.sess, ckpt.model_checkpoint_path)
                # global_step stored in the checkpoint is the number of
                # updates already applied.
                step = self.sess.run(self.net.global_step)
                print('Continue from')
                print(' -> Minibatch update : ', step)

            while step < train_step:
                x, label = self.data.train.next_batch(batch_size)
                _, merged_summary = self.sess.run(
                    [self.net.train, merged_summary_op],
                    feed_dict={self.net.x: x, self.net.label: label}
                )
                # Read the counter in a separate run so the value is the
                # one after this update has been applied.
                step = self.sess.run(self.net.global_step)

                if step % 100 == 0:
                    merged_writer.add_summary(merged_summary, step)

                if step % save_interval == 0:
                    saver.save(self.sess, CKPT_DIR + '/model', global_step=step)
                    print('%s/model-%d saved' % (CKPT_DIR, step))
        finally:
            # Closing flushes buffered events; without it the last summaries
            # can be missing from the log when the process exits.
            merged_writer.close()

    def calculate_accuracy(self):
        """Evaluate on the full MNIST test split, print and return the accuracy."""
        test_x = self.data.test.images
        test_label = self.data.test.labels
        accuracy = self.sess.run(self.net.accuracy,
                                 feed_dict={self.net.x: test_x, self.net.label: test_label})
        print("准确率: %.2f,共测试了%d张图片 " % (accuracy, len(test_label)))
        return accuracy


if __name__ == "__main__":
    app = Train()
    app.train()
    app.calculate_accuracy()

# tensorboard --logdir=./log

# Next listing: predict.py
import tensorflow as tf
import numpy as np
from PIL import Image
from tensorflow.python.framework.ops import disable_eager_execution

disable_eager_execution()

from model import Network

# Originally written for Python 3.6 / TensorFlow 1.4 / Pillow (PIL) 4.3.0;
# this version goes through the tf.compat.v1 API.
#
# Uses the trained TensorFlow model to recognise a handwritten digit.
# Input is a 28 * 28 pixel image, output is a single digit.

CKPT_DIR = 'ckpt'


class Predict:
    """Load the latest checkpoint from ``CKPT_DIR`` and classify digit images."""

    def __init__(self):
        self.net = Network()
        self.sess = tf.compat.v1.Session()
        self.sess.run(tf.compat.v1.global_variables_initializer())

        # Load the trained weights into the session.
        self.restore()

    def restore(self):
        """Restore the most recent checkpoint into ``self.sess``.

        Raises:
            FileNotFoundError: if ``CKPT_DIR`` holds no checkpoint.
        """
        saver = tf.compat.v1.train.Saver()
        ckpt = tf.compat.v1.train.get_checkpoint_state(CKPT_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError("未保存任何模型")

    def predict(self, image_path):
        """Classify one 28x28 image, print the result and return the digit.

        Args:
            image_path: path of an image file; it is converted to 8-bit
                greyscale before being fed to the network.

        Returns:
            The predicted digit as an ``int`` (index of the largest softmax
            output).

        Raises:
            ValueError: if the image is not 28x28 pixels.
        """
        # Read the picture and convert it to 8-bit greyscale; the context
        # manager releases the file handle once the pixels are loaded.
        with Image.open(image_path) as raw:
            img = raw.convert('L')

        if img.size != (28, 28):
            raise ValueError(
                "expected a 28x28 image, got %dx%d: %s"
                % (img.size[0], img.size[1], image_path)
            )

        # Work in float32: on the raw uint8 pixels `1 - pixels` wraps around
        # modulo 256 (255 -> 2, 0 -> 1) instead of inverting the image.
        pixels = np.asarray(img, dtype=np.float32).reshape(1, 784)
        # Scale to [0, 1] and invert.
        # NOTE(review): assumes dark strokes on a light background and that
        # the network was trained on [0, 1] inputs with bright strokes --
        # confirm against the training data loader.
        x = 1.0 - pixels / 255.0

        y = self.sess.run(self.net.y, feed_dict={self.net.x: x})

        # Only one image was fed, so y[0] holds its class probabilities;
        # argmax gives the index of the most likely digit.
        digit = int(np.argmax(y[0]))
        print(image_path)
        print(' -> Predict digit', digit)
        return digit


if __name__ == "__main__":
    app = Predict()
    app.predict('./test_images/0.png')
    app.predict('./test_images/1.png')
    app.predict('./test_images/4.png')

# Test images (测试图片):
文件设置:
data_set:训练用的 MNIST 数据集,运行代码时自动生成
ckpt:保存的模型(检查点),运行代码时自动生成
log:TensorBoard 可视化日志路径,运行代码时自动生成
test_images:待识别图片所在的目录
在project路径运行tensorboard:
tensorboard --logdir=./log
浏览器访问localhost得到可视化结果,端口6006(具体见cmd运行结果):http://localhost:6006/
pycharm中显示结果:
train.py:
predict.py: