深度学习-tensorflow对花的品种进行分类( 二 )


将通过将这些数据集传递给模型来训练模型 。一会儿就好
for image_batch, labels_batch in train_ds:print(image_batch.shape)print(labels_batch.shape)break
(32, 180, 180, 3)
(32,)
是一个形状张量(32,180,180,3) 。这是一批32张形状为的图像(最后一个维度是彩色通道RGB) 。是一个形状(32,)的张量,这些是32幅图像对应的标签 。
可以对和张量调用.numpy()将它们转换为numpy. 。
为性能配置数据集
让我们确保使用缓冲预取,这样就可以在不阻塞I/O的情况下从磁盘生成数据 。这是加载数据时应该使用的两个重要方法 。
AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
标准化的数据
RGB通道值在[0,255]范围内 。这对于神经网络来说并不理想;通常,应该设法使输入值小一些 。在这里,使用一个层来标准化[0,1]范围内的值 。
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))image_batch, labels_batch = next(iter(normalized_ds))first_image = image_batch[0]# Notice the pixels values are now in `[0,1]`.print(np.min(first_image), np.max(first_image))
可以在模型定义中包含该层,这可以简化部署 。我们用第二种方法 。
创建模型
该模型由三个卷积块组成,每个卷积块中有一个最大池层 。有一个完全连接的层,上面有128个单元,由一个relu激活功能激活 。这个模型还没有进行高精度的调整 。
num_classes = 5model = Sequential([layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(num_classes)])
编译模型
对于本教程,选择优化器 。亚当优化和损失 。损失函数 。要查看每个培训阶段的培训和验证准确性,请传递参数 。
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
模型的总结
使用模型的汇总方法查看网络的所有层:
model.summary()

深度学习-tensorflow对花的品种进行分类

文章插图
训练模型
epochs=10history = model.fit(train_ds,validation_data=http://www.kingceram.com/post/val_ds,epochs=epochs)
可视化培训结果
在培训和验证集上创建损失和准确性图 。
acc = history.history['accuracy']val_acc = history.history['val_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(8, 8))plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='Training Accuracy')plt.plot(epochs_range, val_acc, label='Validation Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)plt.plot(epochs_range, loss, label='Training Loss')plt.plot(epochs_range, val_loss, label='Validation Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()