Keras Learning Notes: Training a ResNet Model to Classify CIFAR-10

Contents
Preface
I. Data Preparation
II. Building the Network
III. Network Training
IV. Training Results
Summary
Preface

I have recently been working on quantizing the add operation in ResNet, so I am first writing down how I trained a ResNet to classify CIFAR-10.
I. Data Preparation

If you need to convert the labels to one-hot encoding, you can do it as follows (note that the training code below uses sparse_categorical_crossentropy, which expects integer labels; switch the loss to categorical_crossentropy if you one-hot encode):
label_train = keras.utils.to_categorical(label_train, 10)
label_test = keras.utils.to_categorical(label_test, 10)
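For completeness, here is a minimal data-preparation sketch of the kind this post assumes. The variable names data_train, label_train, data_test, and label_test match the training code later; scaling the pixels to [0, 1] is my assumption, not something stated in the original post.

import numpy as np
from keras.datasets import cifar10

# Load CIFAR-10 (the archive is downloaded on first use).
(data_train, label_train), (data_test, label_test) = cifar10.load_data()

# Scale pixel values to [0, 1]; this normalization choice is an assumption.
data_train = data_train.astype('float32') / 255.0
data_test = data_test.astype('float32') / 255.0

print(data_train.shape, label_train.shape)  # (50000, 32, 32, 3) (50000, 1)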
If cifar10.load_data() is too slow (it downloads the dataset from the internet on first use), you can download cifar-10-python.tar.gz yourself; there are plenty of mirrors online, so no link is given here. Then change the following path in the cifar10.load_data() function to your local path:

origin = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
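If you would rather not edit the Keras source, an alternative is to extract the downloaded archive and read the batch files directly. This is a hedged sketch based on the standard CIFAR-10 "python version" file format; the directory name cifar-10-batches-py is simply where the archive extracts on your machine.

import os
import pickle
import numpy as np

def load_cifar10_local(root='cifar-10-batches-py'):
    """Load CIFAR-10 from a locally extracted cifar-10-python.tar.gz."""
    def load_batch(path):
        with open(path, 'rb') as f:
            batch = pickle.load(f, encoding='bytes')
        # Each batch stores images as (N, 3072) uint8 rows in channel-major order.
        data = batch[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        labels = np.array(batch[b'labels'])
        return data, labels

    train_parts = [load_batch(os.path.join(root, 'data_batch_%d' % i)) for i in range(1, 6)]
    data_train = np.concatenate([p[0] for p in train_parts])
    label_train = np.concatenate([p[1] for p in train_parts])
    data_test, label_test = load_batch(os.path.join(root, 'test_batch'))
    return (data_train, label_train), (data_test, label_test)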
II. Building the Network

1. resnet_block
The code is as follows:

def resnet_block(input, num_filters=16, kernel_size=3, strides=1, activation='relu'):
    # First 3x3 convolution (downsamples when strides > 1).
    # x = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(input)
    x = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same',
               kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(input)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    # Second 3x3 convolution, always stride 1.
    x = Conv2D(num_filters, kernel_size=kernel_size, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    # Shortcut branch: identity when shapes match, otherwise a strided convolution.
    if strides == 1:
        shortcut = input
    else:
        shortcut = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')(input)
    # x = myAdd(activation='relu')([x, shortcut])
    x = keras.layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x

The block is two 3x3 convolutions with batch normalization added to a shortcut branch; when strides > 1, the shortcut is itself a strided convolution so that the two tensors have matching shapes before the add.
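As a quick sanity check (not from the original post), you can push a dummy tensor through the block and confirm that the strided variant halves the spatial size while the add still sees matching shapes:

from keras.layers import Input

# Hypothetical check: a 32x32x16 feature map through a downsampling block.
inp = Input(shape=(32, 32, 16))
out = resnet_block(inp, num_filters=32, strides=2)
print(out.shape)  # expected: (None, 16, 16, 32)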
2. resnet
The code is as follows:
def resnet(input_shape):
    inputs = Input(shape=input_shape)
    # Layer 1: stem convolution.
    x = Conv2D(16, activation='relu', kernel_size=3, strides=1, padding='same')(inputs)
    x = BatchNormalization()(x)
    print('layer1,xshape:', x.shape)
    # Layers 2-7: six residual blocks at 16 filters.
    for i in range(6):
        x = resnet_block(x, num_filters=16, kernel_size=3, strides=1, activation='relu')
    # out: 32*32*16
    # Layers 8-13: six residual blocks at 32 filters, downsampling in the first one.
    for i in range(6):
        if i == 0:
            x = resnet_block(x, num_filters=32, kernel_size=3, strides=2, activation='relu')
        else:
            x = resnet_block(x, num_filters=32, kernel_size=3, strides=1, activation='relu')
    # out: 16*16*32
    # Layers 14-19: six residual blocks at 64 filters, downsampling in the first one.
    for i in range(6):
        if i == 0:
            x = resnet_block(x, num_filters=64, kernel_size=3, strides=2, activation='relu')
        else:
            x = resnet_block(x, num_filters=64, kernel_size=3, strides=1, activation='relu')
    # out: 8*8*64
    x = AveragePooling2D(pool_size=2)(x)
    # out: 4*4*64
    y = Flatten()(x)
    # out: 1024
    outputs = Dense(10, kernel_initializer='he_normal')(y)
    # outputs = myDense(10, kernel_initializer='glorot_normal', use_batchnormal=False)(y)
    outputs = Activation('softmax')(outputs)
    model = Model(inputs=inputs, outputs=outputs)
    return model
III. Network Training

The code is as follows:

model = resnet((32, 32, 3))
model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.summary()
plot_model(model, to_file='./model.png', show_shapes=True)

# checkpoint = ModelCheckpoint(filepath='saved_models/'+config['mode']+'-'+str(config['epoch'])+'-'+str(config['batch_size'])+'-{val_accuracy:.4f}'+'.h5',
#                              monitor='val_accuracy', verbose=1, save_best_only=True)
csv_logger = CSVLogger('saved_models/'+config['mode']+'-'+str(config['epoch'])+'-'+str(config['batch_size'])+'.log', append=False)
lr_reducer = ReduceLROnPlateau(monitor='val_accuracy', factor=0.2, patience=5, mode='max', min_lr=1e-3)

# model.fit(data_train, label_train, batch_size=64, epochs=1, validation_data=(data_test, label_test), verbose=1, callbacks=callbacks)
# Data augmentation: random shifts, horizontal flips, and zoom.
aug = ImageDataGenerator(width_shift_range=0.2, height_shift_range=0.2, horizontal_flip=True, zoom_range=0.2)
aug.fit(data_train)
gen = aug.flow(data_train, label_train, batch_size=config['batch_size'])
model.fit_generator(generator=gen, epochs=config['epoch'], validation_data=(data_test, label_test),
                    callbacks=[lr_reducer, csv_logger], verbose=1)
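The snippet above references a config dictionary that is not shown in the post; something like the following (with hypothetical values) is assumed. Also note that in recent Keras/TensorFlow versions fit_generator is deprecated and model.fit accepts the generator directly.

# Hypothetical configuration; the original post does not show its config dict.
config = {
    'mode': 'resnet',   # used only to name the log/checkpoint files
    'epoch': 100,
    'batch_size': 64,
}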
IV. Training Results

Due to time constraints, the network was not trained to completion.
Summary

I picked up some small training tricks in Keras, such as early stopping, and learned how to build a residual neural network.
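Early stopping is not actually used in the training code above; as a reference, here is a minimal sketch of how it could be added as one more callback (the patience value is an arbitrary choice):

from keras.callbacks import EarlyStopping

# Stop training when validation accuracy has not improved for 10 epochs.
early_stop = EarlyStopping(monitor='val_accuracy', patience=10, mode='max',
                           restore_best_weights=True)

# For example, pass it together with the other callbacks:
# callbacks=[lr_reducer, csv_logger, early_stop]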