MNIST数据集转换为图片数据集的样例程序

it2023-01-08  54

一、编写背景

因为需要对接一个官方的编程API,本人需要自己按其要求搭建一个神经网络,以尝试调用某模块的工作。我参考了Tensorflow的参考书了解了MNIST数据集,然后我准备把MNIST数据集转换为图片格式,以适应API的要求。 同样,这个程序转化出的图片格式的MNIST数据集和标签集也非常适合初学者第一次搭建网络。

二、基础依赖

numpy,opencv,原始MNIST数据集

三、程序主体

# mnist数据集请自行下载,本程序默认数据集在./dataset的文件夹下 from tensorflow.examples.tutorials.mnist import input_data import cv2 import numpy as np mnist = input_data.read_data_sets("./dataset", one_hot=True) # 原数据集的训练集有55000个样本,在此只提取10000个,按需更改 IMAGE_NUM = 10000 print('dataset import done') def image_extract(): for i in range(0, IMAGE_NUM): # 提取长度784的图片像素向量 img = mnist.train.images[i] # 转换成28×28的[0,255]的整数矩阵,以方便cv2保存图片 img_re = (img.reshape(-1, 28) * 255).astype(int) cv2.imwrite('./dataset/images/'+str(i)+'.jpg', img_re ) # print('image ' + str(i) + ' extracted.') print('images extraction done') def label_extract(): labels = [] for i in range(0, IMAGE_NUM): # 提取出长度10的标签向量 lbl = list(mnist.train.labels[i]) # 我以[0-9]的整数进行保存了,实际上用原始的长度10的向量进行训练更合适,自行选择 lbl_num = lbl.index(1) labels = labels + [lbl_num] # print('label of image ' + str(i) + ' is ' + str(lbl_num)) # 保存为npy格式更方便读取 np.save('./dataset/label.npy', labels) print('labels extraction done') if __name__ == '__main__': image_extract() label_extract()
最新回复(0)