Keras 与 TensorFlow 结合得相当紧密：可以用 Keras 创建模型，用 TensorFlow 的 Dataset 喂数据，并把训练结果保存为 TensorFlow 的 Checkpoint。这里给出一个 tf.data.Dataset -> Keras Model -> Checkpoint 的实现，供参考和记录。使用 PyTorch、Keras Generator 以及 TensorFlow 2.0 的实现方式，请参考我的其他博文。
import numpy as np
import tensorflow as tf
import os
import sys
import csv
import cv2 as cv
import matplotlib.pyplot as plt

tf.enable_eager_execution()
# A bare `tf.VERSION` expression only displays in a notebook; print it so a
# plain script run also reports the TensorFlow version in use.
print(tf.VERSION)

AUTOTUNE = tf.data.experimental.AUTOTUNE
# Set by prepare_dataset_trian_test(); the training code derives its step
# counts from these.
IMAGE_COUNT_TRAIN = 0
IMAGE_COUNT_TEST = 0
BATCH_SIZE = 36


def prepare_dataset_trian_test():
    """Build the training and validation input pipelines.

    Reads the label CSV, splits the (filename, label) pairs 80/20 in file
    order, and turns each split into a batched, infinitely repeating
    tf.data.Dataset of ``(images, one_hot_labels)``: images are resized to
    299x299 and scaled to [-1, 1], labels are one-hot encoded over 2 classes.

    Side effect: sets the module globals IMAGE_COUNT_TRAIN and
    IMAGE_COUNT_TEST to the sizes of the two splits.

    Returns:
        (train_ds, validate_ds)
    """

    def prepare_filenames_labels_all():
        """Return (image paths, int labels) for every data row of the CSV."""
        CSV_PATH = 'D:/ai_data/histopathologic-cancer-detection/train_labels.csv'
        TRAIN_IMAGE_FOLDER = 'D:/ai_data/histopathologic-cancer-detection/train'
        # The csv module documents newline='' for files handed to csv.reader.
        with open(CSV_PATH, newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header row
            rows = [row for row in reader if row]  # ignore blank lines
        filenames = [TRAIN_IMAGE_FOLDER + '/' + row[0] + '.tif' for row in rows]
        labels = [int(row[1]) for row in rows]
        return filenames, labels

    def prepare_filenames_labels_train_validate(filenames_all, labels_all, validate_ratio):
        """Split in file order: the first (1 - validate_ratio) share trains, the rest validates."""
        global IMAGE_COUNT_TRAIN
        global IMAGE_COUNT_TEST
        file_quant_train = int(float(len(filenames_all)) * (1.0 - validate_ratio))
        train_filenames = filenames_all[:file_quant_train]
        train_labels = labels_all[:file_quant_train]
        test_filenames = filenames_all[file_quant_train:]
        test_labels = labels_all[file_quant_train:]
        IMAGE_COUNT_TRAIN = len(train_filenames)
        IMAGE_COUNT_TEST = len(test_filenames)
        return train_filenames, train_labels, test_filenames, test_labels

    def image_read_cv2(filename, label):
        """Decode one image with OpenCV; runs eagerly inside tf.py_function."""
        path = filename.numpy().decode()
        image_bgr = cv.imread(path, cv.IMREAD_COLOR)
        if image_bgr is None:
            # cv.imread reports an unreadable/missing file by returning None.
            raise IOError('cannot read image: ' + path)
        # OpenCV decodes to BGR, while the ImageNet weights loaded for
        # InceptionV3 expect RGB input.
        # NOTE: weights fine-tuned on the previous BGR input will see the
        # channels swapped.
        return cv.cvtColor(image_bgr, cv.COLOR_BGR2RGB), label

    def image_resize(image_decoded, label):
        """Resize to 299x299, scale pixels to [-1, 1], one-hot the label."""
        # tf.py_function drops static shapes; restore them so the resize,
        # one-hot and batch ops see known ranks (IMREAD_COLOR is 3-channel).
        image_decoded.set_shape([None, None, 3])
        label.set_shape([])
        image_resized = tf.image.resize_images(image_decoded, [299, 299])
        image_scaled = (image_resized / 255.0 - 0.5) * 2
        label_one_hot = tf.one_hot(label, 2, on_value=1.0, off_value=0.0, axis=-1)
        return image_scaled, label_one_hot

    def build_ds(filenames, labels, shuffle, batch_before_repeat):
        """Shared pipeline: (path, label) pairs -> decoded, resized, batched dataset."""
        ds = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(filenames),
                                  tf.data.Dataset.from_tensor_slices(labels)))
        if shuffle:
            ds = ds.shuffle(buffer_size=300000)
        ds = ds.map(lambda filename, label: tf.py_function(func=image_read_cv2,
                                                            inp=[filename, label],
                                                            Tout=[tf.uint8, tf.int32]),
                    num_parallel_calls=AUTOTUNE)
        ds = ds.map(image_resize, num_parallel_calls=AUTOTUNE)
        if batch_before_repeat:
            ds = ds.batch(BATCH_SIZE).repeat()
        else:
            ds = ds.repeat().batch(BATCH_SIZE)
        # At BATCH_SIZE=36 each batch is 36 x 299 x 299 x 3 float32 (~38 MB);
        # the former fixed prefetch(200) could hold ~7.7 GB per pipeline.
        # Let tf.data choose the buffer size instead.
        return ds.prefetch(buffer_size=AUTOTUNE)

    def prepare_train_ds(filenames, labels):
        # Shuffle, then repeat before batching so batches flow across epochs.
        return build_ds(filenames, labels, shuffle=True, batch_before_repeat=False)

    def prepare_test_ds(filenames, labels):
        # Batch before repeat: a validation pass of ceil(N / BATCH_SIZE) steps
        # then covers every sample exactly once instead of wrapping around.
        return build_ds(filenames, labels, shuffle=False, batch_before_repeat=True)

    filenames_all, labels_all = prepare_filenames_labels_all()
    train_filenames, train_labels, test_filenames, test_labels = \
        prepare_filenames_labels_train_validate(filenames_all, labels_all, 0.2)
    train_image_label_ds = prepare_train_ds(train_filenames, train_labels)
    test_image_label_ds = prepare_test_ds(test_filenames, test_labels)
    return train_image_label_ds, test_image_label_ds


train_image_label_ds, test_image_label_ds = prepare_dataset_trian_test()
keras_ds = train_image_label_ds
keras_validate_ds = test_image_label_ds
InceptionV3 = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet', input_tensor=None, input_shape=(299, 299, 3)) inception_v3 = InceptionV3.get_layer(index = -1).output output = tf.keras.layers.AveragePooling2D((8, 8), strides=(8, 8), name='avg_pool')(inception_v3) output = tf.keras.layers.Flatten(name='flatten')(output) output = tf.keras.layers.Dense(2, activation='softmax', name='predictions')(output) model = tf.keras.models.Model(InceptionV3.input, output) model.trainable = True # for layer in model.layers[:-3]: # layer.trainable = False model.summary() # for x in model.non_trainable_weights: # print(x.name) optimizer = tf.keras.optimizers.SGD(lr = 0.0001, momentum = 0.9, decay = 0.0, nesterov = True) model.compile(loss='binary_crossentropy', optimizer = optimizer, metrics = ['accuracy']) steps_per_epoch = tf.ceil(IMAGE_COUNT_TRAIN/BATCH_SIZE) validation_steps = tf.ceil(IMAGE_COUNT_TEST/BATCH_SIZE) checkpoint_path = 'checkpoint4/cp.ckpt' checkpoint_dir = os.path.dirname(checkpoint_path) cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True, save_best_only=True, monitor='val_acc', mode='max', verbose = 1) tb_callback = tf.keras.callbacks.TensorBoard(log_dir='./Graph4', update_freq='batch',histogram_freq=0, write_graph=True, write_images=True) model.load_weights(checkpoint_path) model.fit(keras_ds, epochs=8, steps_per_epoch=steps_per_epoch, # validation_split = 0.2, validation_data = keras_validate_ds, validation_steps = validation_steps, callbacks = [cp_callback,tb_callback] ) print() print() print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>evaluate>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>') model.load_weights(checkpoint_path) loss,acc = model.evaluate(keras_validate_ds,steps=validation_steps) print("Restored model, accuracy: {:5.2f}%".format(100*acc))```