3.14keras训练的断点保存和恢复

摘要

keras 训练大量数据集时,我们不希望因为断点或者别的干扰下,存在训练中断,然后下次又是重头开始训练的这种情况。好在keras和tensorflow一样都是存在训练中的断点保存和恢复功能,本文主要通过代码例程来演示如何使用keras中的断点保存和恢复功能。

  • [x] Edit By Porter, 积水成渊,蛟龙生焉。

第一步:导入ModelCheckpoint模块

导入断点保存和恢复的模块和断点保存的路径。

1
2
3
from keras.callbacks import ModelCheckpoint
# 断点保存的路径
checkpoint_dir = './work/keras_model/checkpoint-best.hdf5'

第二步:创建模型保存的本地文件夹

1
2
3
# 创建断点复训文件夹
if not os.path.exists('./work/keras_model/'):
os.mkdir('./work/keras_model/')

第三步:检测是否存在上次训练保存的模型参数

每次运行程序前,检测本地文件夹里是否存在上次保存的模型训练参数保存的文件夹

1
2
3
4
5
if os.path.exists(checkpoint_dir):
print('INFO:checkpoint exists, Load weights from %s\n'%checkpoint_dir)
model.load_weights(checkpoint_dir) # 加载断点文件夹
else:
print('No checkpoint found')

第四步:创建模型断点保存函数

这个函数作为模型训练中的模型参数保存的文件,我们每次按指定次数,将当前的训练参数保存在本地的文件中。

1
2
3
4
5
6
checkpoint = ModelCheckpoint(checkpoint_dir,
monitor='val_loss',
save_weights_only=True,
verbose=1,
save_best_only=True,
period=1)

第五步:通过开始训练的函数调用第四步中的回调函数

1
2
3
4
5
6
7
8
9
#拟合模型
history = model.fit_generator(
train_generator,
steps_per_epoch=29620//50,#迭代进入下一轮次需要抽取的批次
epochs=30,#数据迭代的轮数
validation_data=val_generator,
validation_steps=30,#验证集用于评估的批次
initial_epoch=27, #再训练时的初始回合数,一般断点后,这个值设置你之前断点结束后的值
callbacks=[checkpoint])# 回调函数,断点的保存
文章目录
  1. 1. 摘要
    1. 1.0.1. 第一步:导入ModelCheckpoint模块
    2. 1.0.2. 第二步:创建模型保存的本地文件夹
    3. 1.0.3. 第三步:检测是否存在上次训练保存的模型参数
    4. 1.0.4. 第四步:创建模型断点保存函数
    5. 1.0.5. 第五步:通过开始训练的函数调用第四步中的回调函数
|