fit vs. fit_generator

fit

  • Loads all the data into GPU memory at once
  • Trains the model for a given number of epochs (iterations on a dataset)
  • fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)
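
A minimal sketch of fit on in-memory arrays; the toy model, layer sizes, and random data below are hypothetical, just to show the call:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

model = Sequential([Dense(16, activation='relu', input_shape=(8,)),
                    Dense(1, activation='sigmoid')])
model.compile(optimizer='adam', loss='binary_crossentropy')

x = np.random.rand(1000, 8)                # all 1000 samples sit in memory at once
y = np.random.randint(0, 2, size=(1000,))

# fit slices x/y into batches internally and shuffles each epoch by default
model.fit(x, y, batch_size=32, epochs=5, validation_split=0.1)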

fit_generator

  • Trains the model on data generated batch-by-batch by a Python generator
  • The generator uses yield to hand each batch of data back to fit_generator
  • fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
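
A minimal sketch of the call, reusing the hypothetical model and data from the fit example above; simple_generator is a made-up, non-shuffling stand-in (a proper shuffled generator is developed in the sections below):

def simple_generator(x, y, batch_size):
    while True:                 # loop forever; see the tips below
        for i in range(len(y) // batch_size):
            yield x[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]

# steps_per_epoch is required: a generator has no length,
# so Keras must be told how many batches make up one epoch
model.fit_generator(simple_generator(x, y, 32),
                    steps_per_epoch=len(y) // 32,
                    epochs=5)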

About yield

  • A keyword similar to return; it turns a function into a generator, which at run time is driven by a loop (or repeated calls)
  • Execution proceeds like a normal function, but when it reaches yield the function suspends and returns the current value; on the next call it resumes from where it stopped and runs until it hits yield again
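
A tiny self-contained demo of this suspend-and-resume behavior (the function name is made up for illustration):

def count_up_to(n):
    i = 0
    while i < n:
        yield i      # suspend here and hand i back to the caller
        i += 1       # on the next iteration, resume from this line

for value in count_up_to(3):
    print(value)     # prints 0, 1, 2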

How to write a generator

  • Example code (the original pseudocode, fixed up so it runs)

import random

def generator_batch_data_random(x, y, batch_size):
    # Pull each batch into GPU memory only when it is needed,
    # keeping the memory footprint low
    y_len = len(y)
    loopcount = y_len // batch_size
    while True:
        # random.shuffle works in place, so build an index list first
        indices = list(range(loopcount))
        random.shuffle(indices)
        for i in indices:
            yield x[i * batch_size:(i + 1) * batch_size], \
                  y[i * batch_size:(i + 1) * batch_size]

Reference: 利用fit_generator最小化显存占用比率 (Using fit_generator to minimize GPU memory usage)
  • Tips
    • Shuffle the data yourself inside the generator, as the code above does (fit shuffles by default with shuffle=True, whereas fit_generator leaves shuffling to us).
    • Write the generator as an infinite loop with while True, because the function is not called again at the start of each epoch.
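
Before training, it can help to pull one batch by hand to check shapes; next() is exactly what fit_generator does internally at each step. A minimal sketch, assuming the hypothetical model and x/y arrays from the fit example above:

gen = generator_batch_data_random(x, y, batch_size=32)
xb, yb = next(gen)           # advance the generator to its first yield
print(xb.shape, yb.shape)    # e.g. (32, 8) (32,)

# Hand a fresh generator to fit_generator; steps_per_epoch is the
# number of batches that make up one epoch
model.fit_generator(generator_batch_data_random(x, y, 32),
                    steps_per_epoch=len(y) // 32,
                    epochs=5)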

train_on_batch

Differences

With fit_generator, you can use a generator for the validation data as well. In general I would recommend using fit_generator, but using train_on_batch works fine too. These methods exist only for convenience in different use cases; there is no “correct” method.
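
For comparison, the same training written with train_on_batch, where we drive the epoch/batch loop ourselves. A minimal sketch, reusing the hypothetical model, data, and generator from above:

gen = generator_batch_data_random(x, y, batch_size=32)
steps = len(y) // 32

for epoch in range(5):
    for _ in range(steps):
        xb, yb = next(gen)
        loss = model.train_on_batch(xb, yb)   # one gradient update on one batch
    print('epoch', epoch, 'last batch loss', loss)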
