共计 1271 个字符,预计需要花费 4 分钟才能阅读完成。
本次教程的目标是率领大家学会根本的花朵图像分类
首先咱们来介绍下数据集,该数据集有 5 种花,一共有 3670 张图片,别离是 daisy、dandelion、roses、sunflowers、tulips,数据寄存构造如下所示
咱们能够展现下 roses 的几张图片
接下来咱们须要加载数据集,而后对数据集进行划分,最初造成训练集、验证集、测试集,留神此处的验证集是从训练集切分进去的,比例是 8:2
对数据进行摸索的时候,咱们发现原始的像素值是 0 -255,为了模型训练更稳固以及更容易收敛,咱们须要标准化数据集,一般来说就是把像素值缩放到 0 -1,能够用上面的 layer 来实现
normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)
为了使训练的时候 I / O 不成为瓶颈,咱们能够进行如下设置
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
下一步就是模型搭建,而后对模型进行训练
num_classes = 5
model = tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.Rescaling(1./255),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(num_classes)
])
model.compile(
optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(
train_ds,
validation_data=val_ds,
epochs=3
)
从上图的训练记录能够发现,该模型处于欠拟合状态,咱们能够通过多训练几轮来解决这个问题,而且为了疾速试验,咱们这里用了一个非常简单的模型,咱们能够通过更换更强的模型,来晋升模型的体现
代码链接: https://codechina.csdn.net/cs…
正文完