订阅
纠错
加入自媒体

使用数据增强从头开始训练卷积神经网络(CNN)

2022-11-24 14:21
磐创AI
关注

介绍

该文致力于处理神经网络中的过度拟合。

过度拟合将是你主要担心的问题,因为你仅使用 2000 个数据样本训练模型。存在一些有助于克服过度拟合的方法,即 dropout 和权重衰减(L2 正则化)。

我们将讨论数据增强,这是计算机视觉独有的,在使用深度学习模型解释图像时,数据增强在任何地方都会用到。

数据增强

学习示例不足会阻止你训练可以泛化到新数据的模型,从而导致过度拟合。如果你有无限的数据,你的模型将暴露于当前数据分布的所有特征,从而防止过度拟合。

通过增加具有不同随机变化的样本来产生逼真的图像,数据增强使用现有的训练样本来生成更多的训练数据。

你的模型不应在训练期间两次查看同一图像。这使模型更加通用并暴露了数据的其他特征。

Keras 可以通过使用ImageDataGenerator函数定义要应用于图像的各种随机变换来实现这一点。

让我们从一个插图开始。

####-----data augmentation configuration via ImageDataGenerator-------####

datagen = ImageDataGenerator(

rotation=40,

width_shift=0.2,

height_shift=0.2,

shear=0.2,

zoom=0.2,

horizontal_flip=True,

fill_mode='nearest')

让我们快速回顾一下这段代码:

· rotation:这是图像随机旋转的范围。它的容量在(0-180)度之间。

· width_shift 和 height_shift:范围(作为总宽度或高度的一部分),在其中垂直或水平随机翻转图片。

· shear:用于随机应用剪切变换。

· zoom:用于随机缩放图像。

· Horizontal_flip :用于随机水平翻转一半图像

· fill_mode:是用于填充新生成的像素的方法,这些像素可能在旋转或宽度/高度变化后出现。

显示增强图像

####-----Let's display some randomly augmented training images-------####

from keras.preprocessing import image

fnames = [os.path.join(train_cats_dir, fname) for fname in os.listdir(train_cats_dir)]

img_path = fnames[3]

img = image.load_img(img_path, target_size=(150, 150))

x = image.img_to_array(img)

x = x.reshape((1,) + x.shape)

i = 0

for batch in datagen.flow(x, batch_size=1):

plt.figure(i)

imgplot = plt.imshow(image.array_to_img(batch[0]))

i += 1

if i % 4 == 0:

  break

plt.show()

图:使用数据增强生成猫图片

如果你使用数据增强设置训练新网络,网络将永远不会收到两次相同的输入。

然而,因为它只接收来自少量原始照片的输入,这些输入仍然是高度相关的;你只能重新混合已经存在的信息。

因此,这可能不足以消除过度拟合。在密集链接分类器之前,你应该在算法中包含一个 Dropout 层,以进一步对抗过度拟合。

实时数据增强应用

1. 医疗保健

管理数据集不是医学成像应用的解决方案,因为获取大量经过专业标记的样本需要很长时间和金钱。

通过增强设计的网络必须比类似 X 射线图片中的预测变化更可靠和真实。但是,我们可以通过使用数据增强来增加后续插图中的数据数量。

图:X 射线图像中的数据增强

2. 自动驾驶汽车

自动驾驶汽车是一个不同的使用主题,其中数据增强是有益的。

例如,CARLA旨在在物理模拟中产生灵活性和真实感。CARLA 旨在促进自动驾驶系统的结果、指导和验证。它基于虚幻引擎 4,并提供了一个完整的模拟器环境,用于在安全的环境中测试自动驾驶技术。

当数据稀缺成为问题时,使用强化学习技术创建的模拟环境可以帮助人工智能系统的训练和测试。对模拟环境进行建模以创建真实场景的能力为数据增强开辟了一个充满可能性的世界。

从头开始定义 CNN 模型

####------Defining CNN, including dropout--------####

model = models.Sequential()

model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))

model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(64, (3, 3), activation='relu'))

model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(128, (3, 3), activation='relu'))

model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(128, (3, 3), activation='relu'))

model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Flatten())

model.add(layers.Dropout(0.5))

model.add(layers.Dense(512, activation='relu'))

model.add(layers.Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer=optimizers.RMSprop(lr=1e-4), metrics=['acc'])

让我们使用数据增强和损失函数来训练网络。

####-------Train CNN using data-augmentation--------#####

train_datagen = ImageDataGenerator(rescale=1./255, rotation=40, width_shift=0.2, height_shift=0.2, shear=0.2, zoom=0.2, horizontal_flip=True,)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(train_dir, target_size=(150, 150), batch_size=32, class_mode='binary')

validation_generator = test_datagen.flow_from_directory(validation_dir, target_size=(150, 150), batch_size=32, class_mode='binary')

history = model.fit_generator(train_generator, steps_per_epoch=100, epochs=100, validation_data=validation_generator, validation_steps=50)

####-------Save the model--------#####

model.save('cats_and_dogs_small_2.h5')

由于数据增强和丢失,模型不再过度拟合。因为训练曲线和验证曲线彼此接近。有了这个准确度,你就超过了非正则化模型 15%,达到了 82%。让我们绘制曲线。

在训练期间显示损失曲线和准确度

通过使用其他正则化方法和微调网络参数(例如每个卷积层的过滤器数量或网络中的层数),你可以实现更高的准确度,高达 86% 或 87%。

但是,由于你要处理的数据很少,因此仅通过从头开始训练自己的 CNN 来达到更高的水平将是一项挑战。

你必须采用预训练模型作为进一步的步骤,以提高你在此挑战中的准确性。

结论

1. 训练数据的质量、数量和上下文本质会显着影响深度学习模型的准确性。但开发深度学习模型的最大问题之一是缺乏数据。

2. 在生产使用方法中获取此类数据可能既昂贵又耗时。公司使用数据增强这一低成本且高效的技术来更快地开发高精度 AI 模型,并减少对收集和准备训练实例的依赖。

3. 本文解释了我们如何使用数据增强技术来训练我们的模型。当收集大量数据具有挑战性时,会使用数据增强。正如博客中所讨论的,医疗保健和无人驾驶汽车是使用这种方法的两个最著名的领域。

       原文标题 : 使用数据增强从头开始训练卷积神经网络(CNN)

声明: 本文由入驻维科号的作者撰写,观点仅代表作者本人,不代表OFweek立场。如有侵权或其他问题,请联系举报。

发表评论

0条评论,0人参与

请输入评论内容...

请输入评论/评论长度6~500个字

您提交的评论过于频繁,请输入验证码继续

暂无评论

暂无评论

人工智能 猎头职位 更多
扫码关注公众号
OFweek人工智能网
获取更多精彩内容
文章纠错
x
*文字标题:
*纠错内容:
联系邮箱:
*验 证 码:

粤公网安备 44030502002758号