摘要
keras 训练大量数据集时,我们不希望因为断点或者别的干扰下,存在训练中断,然后下次又是重头开始训练的这种情况。好在keras和tensorflow一样都是存在训练中的断点保存和恢复功能,本文主要通过代码例程来演示如何使用keras中的断点保存和恢复功能。
- [x] Edit By Porter, 积水成渊,蛟龙生焉。
第一步:导入ModelCheckpoint模块
导入断点保存和恢复的模块和断点保存的路径。
1 2 3
| from keras.callbacks import ModelCheckpoint
checkpoint_dir = './work/keras_model/checkpoint-best.hdf5'
|
第二步:创建模型保存的本地文件夹
1 2 3
| if not os.path.exists('./work/keras_model/'): os.mkdir('./work/keras_model/')
|
第三步:检测是否存在上次训练保存的模型参数
每次运行程序前,检测本地文件夹里是否存在上次保存的模型训练参数保存的文件夹
1 2 3 4 5
| if os.path.exists(checkpoint_dir): print('INFO:checkpoint exists, Load weights from %s\n'%checkpoint_dir) model.load_weights(checkpoint_dir) else: print('No checkpoint found')
|
第四步:创建模型断点保存函数
这个函数作为模型训练中的模型参数保存的文件,我们每次按指定次数,将当前的训练参数保存在本地的文件中。
1 2 3 4 5 6
| checkpoint = ModelCheckpoint(checkpoint_dir, monitor='val_loss', save_weights_only=True, verbose=1, save_best_only=True, period=1)
|
第五步:通过开始训练的函数调用第四步中的回调函数
1 2 3 4 5 6 7 8 9
| history = model.fit_generator( train_generator, steps_per_epoch=29620//50, epochs=30, validation_data=val_generator, validation_steps=30, initial_epoch=27, callbacks=[checkpoint])
|