Keras | 常用Callbacks及范例
Introduce
在使用Sequential
或 Model
类型的 .fit()
或.fit_generator
方法训练时,可以设置callback
关键字参数,完成一系列回调功能,如:
1 | hist=model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0,callbacks=[ckpt, csv_logger, lr]) |
History
https://www.cnblogs.com/chendai21/p/8137601.html
1 | keras.callbacks.History() |
该回调函数被自动启用到每一个 Keras 模型,由 fit
方法返回。
History.history
是一个记录了连续迭代的训练/验证(如果存在)损失值和评估值的字典。
代码内容为训练历史可视化的一个示例。
1 | import matplotlib.pyplot as plt |
ModelCheckpoint
https://www.jianshu.com/p/0711f9e54dd2
1 | keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1) |
在每个训练期之后保存模型。
filepath
可以包括命名格式选项,可以由 epoch
的值和 logs
的键来填充。如果 filepath
是 weights.{epoch:02d}-{val_loss:.2f}.hdf5
, 那么模型被保存的的文件名就会有训练轮数和验证损失。
例1:DnCNN-keras-master
1 | ckpt = ModelCheckpoint(save_dir+'/model_{epoch:02d}.h5',monitor='val_loss',verbose=0,period=args.save_every) |
LearningRateScheduler
1 | keras.callbacks.LearningRateScheduler(schedule, verbose=0) |
学习速率定时器。
例1:DnCNN-keras-master
1 | def step_decay(epoch): |
ReduceLROnPlateau
1 | keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0) |
当标准评估停止提升时,降低学习速率。
当学习停止时,模型总是会受益于降低 2-10 倍的学习速率。
例1:Keras Document
1 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,patience=5, min_lr=0.001) |
TensorBoard
https://www.jianshu.com/p/aa910632be2f
1 | keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, batch_size=32, write_graph=True, write_grads=False, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None, embeddings_data=None, update_freq='epoch') |
这个回调函数为 Tensorboard 编写一个日志, 这样你可以可视化测试和训练的标准评估的动态图像, 也可以可视化模型中不同层的激活值直方图。
1 | tensorboard --logdir=/full_path_to_your_logs |
CSVLogger
1 | keras.callbacks.CSVLogger(filename, separator=',', append=False) |
把epoch结果保存到csv 文件的回调函数。
例1:Keras Document
1 | csv_logger = CSVLogger('training.log') |
例2:DnCNN-keras-master
1 | csv_logger = CSVLogger(save_dir+'/log.csv', append=True, separator=',') |
LambdaCallback
1 | keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None, on_batch_begin=None, on_batch_end=None, on_train_begin=None, on_train_end=None) |
在训练进行中创建简单,自定义的回调函数的回调函数。
例1:Keras Document
1 | # 在每一个批开始时,打印出批数。 |