深度学习-tensorflow对花的品种进行分类( 三 )


从图中可以看到,训练精度和验证精度相差很大,模型在验证集上仅实现了约60%的准确性 。
看看哪里出了问题,尝试提高模型的整体性能 。
过度拟合
在上面的图中,训练精度随时间线性增加,而验证精度在训练过程中停滞在60%左右 。此外,训练和验证准确性之间的差异是明显的——这是过度拟合的迹象 。
当训练样本数量很少时,模型有时会从训练样本的噪声或不需要的细节中学习,这在一定程度上会对模型在新样本上的性能产生负面影响 。这种现象被称为过拟合 。这意味着模型在新的数据集中泛化时会有困难 。
在训练过程中有多种方法可以对抗过拟合 。在本教程中,将使用数据增强并将添加到您的模型中 。
数据增加
过拟合通常发生在只有少量训练例子的情况下 。数据增强采用的方法是从现有的例子中生成额外的训练数据,通过使用随机变换来增强它们,产生看起来可信的图像 。这有助于向数据的更多方面公开模型,并更好地一般化 。
使用来自tf.keras...的层来实现数据增强 。它们可以像其他层一样被包含在你的模型中,并在GPU上运行 。
data_augmentation = keras.Sequential([layers.experimental.preprocessing.RandomFlip("horizontal", input_shape=(img_height, img_width,3)),layers.experimental.preprocessing.RandomRotation(0.1),layers.experimental.preprocessing.RandomZoom(0.1),])
让我们通过多次对同一幅图像应用数据增强来可视化几个增强示例:
plt.figure(figsize=(10, 10))for images, _ in train_ds.take(1):for i in range(9):augmented_images = data_augmentation(images)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_images[0].numpy().astype("uint8"))plt.axis("off")
稍后将使用数据增强来训练模型 。
另一种减少过拟合的技术是在网络中引入,这是一种正则化的形式 。
当将应用到一个层时,它会在训练过程中从该层随机退出一些输出单位(通过将激活设置为零) 。以一个小数作为输入值,形式有0.1、0.2、0.4等 。这意味着从应用层中随机去掉10%、20%或40%的输出单元 。
让我们用层创建一个新的神经网络 。退出,然后用增强图像训练它 。
model = Sequential([data_augmentation,layers.experimental.preprocessing.Rescaling(1./255),layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Dropout(0.2),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(num_classes)])
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
model.summary()
epochs = 15history = model.fit(train_ds,validation_data=http://www.kingceram.com/post/val_ds,epochs=epochs)
可视化培训结果
应用数据增强和后,过拟合减少,训练和验证精度更接近 。
acc = history.history['accuracy']val_acc = history.history['val_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(8, 8))plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')plt.plot(epochs_range, val_acc, label='Validation Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)plt.plot(epochs_range, loss, label='Training Loss')plt.plot(epochs_range, val_loss, label='Validation Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()