Matplotlib is a Python 2D plotting library that produces publication-quality figures in a variety of hardcopy formats and interactive environments across platforms. Matplotlib can be used in Python scripts, the Python and IPython shells, the Jupyter notebook, web application servers, and four graphical user interface toolkits.

显示数据集图片

from matplotlib import pyplot as plt
def plot_real(n_ex=16, dim=(4, 4), figsize=(10, 10), images=None):
    """Show ``n_ex`` randomly chosen dataset images in a ``dim`` grid.

    Args:
        n_ex: Number of images to draw.
        dim: ``(rows, cols)`` of the subplot grid; it must have at least
            ``n_ex`` cells.
        figsize: Matplotlib figure size passed to ``plt.figure``.
        images: Array to sample from. Defaults to the module-level
            ``X_train``. It is indexed as ``images[idx, :, :, :]`` and each
            sample as ``sample[:, :, 0]``, so a 4-D array is required and
            only index 0 of the last axis is drawn (presumably a
            channels-last grayscale layout -- confirm against the data).

    Raises:
        ValueError: If the grid has fewer than ``n_ex`` cells, or there are
            no images to sample from.
    """
    if images is None:
        # Keep the original behaviour of reading the global training set.
        images = X_train
    rows, cols = dim
    if n_ex > rows * cols:
        raise ValueError(
            "n_ex={} does not fit in a {}x{} grid".format(n_ex, rows, cols))
    n_total = images.shape[0]
    if n_total == 0:
        raise ValueError("no images to sample from")

    # Sample without replacement so the grid shows distinct images; only
    # allow repeats when more cells than images are requested.
    idx = np.random.choice(n_total, size=n_ex, replace=n_ex > n_total)
    samples = images[idx, :, :, :]

    # Draw each selected image in its own cell of the rows x cols grid.
    plt.figure(figsize=figsize)
    for i, sample in enumerate(samples):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(sample[:, :, 0], cmap='bone')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
# Preview a random sample of the dataset using the default 16-image, 4x4 grid.
plot_real(n_ex=16, dim=(4, 4), figsize=(10, 10))

绘制loss

# Loss history consumed by plot_loss: "d" feeds the discriminator curve,
# "g" the generator curve. Each key gets its own independent list.
losses = dict(d=[], g=[])
def plot_loss(losses, figsize=(10, 8)):
    """Redraw the discriminator and generator loss curves in the notebook.

    Args:
        losses: Mapping with keys ``"d"`` (discriminator) and ``"g"``
            (generator); each value is a sequence passed to ``plt.plot``.
        figsize: Matplotlib figure size passed to ``plt.figure``.
    """
    # Replace the previous plot in the cell output rather than stacking a
    # new one underneath; wait=True holds the clear until new output arrives.
    display.clear_output(wait=True)
    # Do not display plt.gcf() here: at this point the new figure does not
    # exist yet, so gcf() returns a stale figure or creates an empty one,
    # which is then rendered and left open on every call.
    plt.figure(figsize=figsize)
    plt.plot(losses["d"], label='discriminative loss')
    plt.plot(losses["g"], label='generative loss')
    plt.legend()
    plt.show()