def train_for_n(nb_epoch=5000, plt_frq=25, BATCH_SIZE=32):
    for e in tqdm(range(nb_epoch)):

        # Make generative images
        image_batch = X_train[np.random.randint(0, X_train.shape[0], size=BATCH_SIZE), :, :, :]
        noise_gen = np.random.uniform(0, 1, size=[BATCH_SIZE, n_cat])
        generated_images = generator.predict(noise_gen)

        # Train discriminator on generated images
        X = np.concatenate((image_batch, generated_images))
        y = np.zeros([2 * BATCH_SIZE, 2])
        y[0:BATCH_SIZE, 1] = 1   # real images  -> class 1
        y[BATCH_SIZE:, 0] = 1    # generated images -> class 0
        make_trainable(discriminator, True)
        d_loss = discriminator.train_on_batch(X, y)
        losses["d"].append(d_loss)

        # Train generator-discriminator stack on input noise with the "real" output class
        noise_tr = np.random.uniform(0, 1, size=[BATCH_SIZE, n_cat])
        y2 = np.zeros([BATCH_SIZE, 2])
        y2[:, 1] = 1             # label noise as "real" so the generator learns to fool the discriminator
        make_trainable(discriminator, False)
        g_loss = GAN.train_on_batch(noise_tr, y2)
        losses["g"].append(g_loss)

        # Update plots
        if e % plt_frq == plt_frq - 1:
            plot_loss(losses)
            plot_gen()
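The loop relies on a make_trainable helper to freeze the discriminator while the stacked GAN updates the generator, but the helper itself is not shown in this snippet. A minimal sketch, assuming a standard Keras Model, simply toggles the trainable flag on the network and each of its layers:

def make_trainable(net, val):
    # Freeze or unfreeze every layer of the network.
    # Note: in Keras the new setting only takes effect the next time
    # the model containing these layers is compiled.
    net.trainable = val
    for layer in net.layers:
        layer.trainable = val

Because the labels y and y2 are one-hot vectors of length 2, this setup assumes the discriminator ends in a two-way softmax (real vs. generated) rather than a single sigmoid unit.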
Example 2
datagen = ImageDataGenerator(
    featurewise_center=False,             # set input mean to 0 over the dataset
    samplewise_center=False,              # set each sample mean to 0
    featurewise_std_normalization=False,  # divide inputs by std of the dataset
    samplewise_std_normalization=False,   # divide each input by its std
    zca_whitening=False,                  # apply ZCA whitening
    rotation_range=10,                    # randomly rotate images in the range (degrees, 0 to 180)
    zoom_range=0.1,                       # randomly zoom image
    width_shift_range=0.1,                # randomly shift images horizontally (fraction of total width)
    height_shift_range=0.1,               # randomly shift images vertically (fraction of total height)
    horizontal_flip=False,                # randomly flip images
    vertical_flip=False)                  # randomly flip images
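With only mild rotation, zoom, and shift enabled (and both flips disabled), this generator produces small perturbations of the training images. A minimal usage sketch, assuming X_train/y_train arrays and an already compiled Keras model, and using the Keras 2 fit_generator signature (the batch size and epoch count here are illustrative):

from keras.preprocessing.image import ImageDataGenerator

# Compute dataset-wide statistics; only required if the featurewise_* options
# are enabled, but harmless to call either way.
datagen.fit(X_train)

# Stream randomly augmented batches to the model during training.
model.fit_generator(datagen.flow(X_train, y_train, batch_size=32),
                    steps_per_epoch=len(X_train) // 32,
                    epochs=30)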