关于人工智能:使用PyTorch进行小样本学习的图像分类

60次阅读

共计 6596 个字符,预计需要花费 17 分钟才能阅读完成。

近年来,基于深度学习的模型在指标检测和图像识别等工作中表现出色。像 ImageNet 这样具备挑战性的图像分类数据集,蕴含 1000 种不同的对象分类,当初一些模型曾经超过了人类程度上。然而这些模型依赖于监督训练流程,标记训练数据的可用性对它们有重大影响,并且模型可能检测到的类别也仅限于它们承受训练的类。

因为在训练过程中没有足够的标记图像用于所有类,这些模型在事实环境中可能不太有用。并且咱们心愿的模型可能辨认它在训练期间没有见到过的类,因为简直不可能在所有潜在对象的图像上进行训练。咱们将从几个样本中学习的问题被称为“少样本学习 Few-Shot learning”。

什么是小样本学习?

少样本学习是机器学习的一个子畛域。它波及到在只有多数训练样本和监督数据的状况下对新数据进行分类。只需大量的训练样本,咱们创立的模型就能够相当好地执行。

思考以下场景: 在医疗畛域,对于一些不常见的疾病,可能没有足够的 x 光图像用于训练。对于这样的场景,构建一个小样本学习分类器是完满的解决方案。

小样本的变动

一般来说,钻研人员确定了四种类型:

  1. N-Shot Learning (NSL)
  2. Few-Shot Learning (FSL)
  3. One-Shot Learning (OSL)
  4. Zero-Shot Learning (ZSL)

当咱们议论 FSL 时,咱们通常指的是 N-way-K-Shot 分类。N 代表类别数,K 代表每个类中要训练的样本数。所以 N -Shot Learning 被视为比所有其余概念更宽泛的概念。能够说 Few-Shot、One-Shot 和 Zero-Shot 是 NSL 的子畛域。而零样本学习旨在在没有任何训练示例的状况下对看不见的类进行分类。

在 One-Shot Learning 中,每个类只有一个样本。Few-Shot 每个类有 2 到 5 个样本,也就是说 Few-Shot 是更灵便的 One-Shot Learning 版本。

小样本学习办法

通常,在解决 Few Shot Learning 问题时应思考两种办法:

数据级办法 (DLA)

这个策略非常简单,如果没有足够的数据来创立实体模型并避免欠拟合和过拟合,那么就应该增加更多数据。正因为如此,许多 FSL 问题都能够通过利用来更大大的根底数据集的更多数据来解决。根本数据集的显着特色是它短少形成咱们对 Few-Shot 挑战的反对集的类。例如,如果咱们想要对某种鸟类进行分类,则根底数据集可能蕴含许多其余鸟类的图片。

参数级办法 (PLA)

从参数级别的角度来看,Few-Shot Learning 样本绝对容易过拟合,因为它们通常具备大的高维空间。限度参数空间、应用正则化和应用适当的损失函数将有助于解决这个问题。大量的训练样本将被模型泛化。

通过将模型疏导到广大的参数空间能够进步性能。因为不足训练数据,失常的优化办法可能无奈产生精确的后果。

因为下面的起因,训练咱们的模型以发现通过参数空间的最佳门路,产生最佳的预测后果。这种办法被称为元学习。

小样本学习图像分类算法

有 4 种比拟常见的小样本学习的办法:

与模型无关的元学习 Model-Agnostic Meta-Learning

基于梯度的元学习 (GBML) 准则是 MAML 的根底。在 GBML 中,元学习者通过根底模型训练和学习所有工作示意的共享特色来取得先前的教训。每次有新工作要学习时,元学习器都会利用其现有教训和新工作提供的最大量的新训练数据进行微调训练。

个别状况下,如果咱们随机初始化参数通过几次更新算法将不会收敛到良好的性能。MAML 试图解决这个问题。MAML 只需几个梯度步骤并且保障没有适度拟合的前提下,为元参数学习器提供了牢靠的初始化,这样能够对新工作进行最佳疾速学习。

步骤如下:

  1. 元学习者在每个分集(episode)开始时创立本人的正本 C,
  2. C 在这一分集上进行训练(在 base-model 的帮忙下),
  3. C 对查问集进行预测,
  4. 从这些预测中计算出的损失用于更新 C,
  5. 这种状况始终继续到实现所有分集的训练。

这种技术的最大劣势在于,它被认为与元学习算法的抉择无关。因而 MAML 办法被宽泛用于许多须要疾速适应的机器学习算法,尤其是深度神经网络。

匹配网络 Matching Networks

为解决 FSL 问题而创立的第一个度量学习办法是匹配网络 (MN)。

当应用匹配网络办法解决 Few-Shot Learning 问题时须要一个大的根底数据集。。

将该数据集分为几个分集之后,对于每一分集,匹配网络进行以下操作:

  • 来自反对集和查问集的每个图像都被馈送到一个 CNN,该 CNN 为它们输入特色的嵌入
  • 查问图像应用反对集训练的模型失去嵌入特色的余弦间隔,通过 softmax 进行分类
  • 分类后果的穿插熵损失通过 CNN 反向流传更新特色嵌入模型

匹配网络能够通过这种形式学习构建图像嵌入。MN 可能应用这种办法对照片进行分类,并且无需任何非凡的类别先验常识。他只有简略地比拟类的几个实例就能够了。

因为类别因分集而异,因而匹配网络会计算对类别辨别很重要的图片属性(特色)。而当应用规范分类时,算法会抉择每个类别独有的特色。

原型网络 Prototypical Networks

与匹配网络相似的是原型网络(PN)。它通过一些轻微的变动来进步算法的性能。PN 比 MN 获得了更好的后果,但它们训练过程实质上是雷同的,只是比拟了来自反对集的一些查问图片嵌入,然而 原型网络提供了不同的策略。

咱们须要在 PN 中创立类的原型:通过对类中图像的嵌入进行均匀而创立的类的嵌入。而后仅应用这些类原型来比拟查问图像嵌入。当用于单样本学习问题时,它可与匹配网络相媲美。

关系网络 Relation Network

关系网络能够说继承了所有下面提到办法的钻研的后果。RN 是基于 PN 思维的但蕴含了显著的算法改良。

该办法应用的间隔函数是可学习的,而不是像以前钻研的当时定义它。关系模块位于嵌入模块之上,嵌入模块是从输出图像计算嵌入和类原型的局部。

可训练的关系模块(间隔函数)输出是查问图像的嵌入与每个类的原型,输入为每个分类匹配的关系分数。关系分数通过 Softmax 失去一个预测。

应用 Open-AI Clip 进行零样本学习

CLIP(Contrastive Language-Image Pre-Training)是一个在各种(图像、文本)对上训练的神经网络。它无需间接针对工作进行优化,就能够为给定的图像来预测最相干的文本片段(相似于 GPT-2 和 3 的零样本的性能)。

CLIP 在 ImageNet“零样本”上能够达到原始 ResNet50 的性能,而且须要不应用任何标记示例,它克服了计算机视觉中的几个次要挑战,上面咱们应用 Pytorch 来实现一个简略的分类模型。

引入包

 ! pip install ftfy regex tqdm
 ! pip install git+https://github.com/openai/CLIP.gitimport numpy as np
 import torch
 from pkg_resources import packaging
 
 print("Torch version:", torch.__version__)

加载模型

 import clipclip.available_models() # it will list the names of available CLIP modelsmodel, preprocess = clip.load("ViT-B/32")
 model.cuda().eval()
 input_resolution = model.visual.input_resolution
 context_length = model.context_length
 vocab_size = model.vocab_size
 
 print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
 print("Input resolution:", input_resolution)
 print("Context length:", context_length)
 print("Vocab size:", vocab_size)

图像预处理

咱们将向模型输出 8 个示例图像及其文本形容,并比拟对应特色之间的相似性。

分词器不辨别大小写,咱们能够自在地给出任何适合的文本形容。

 import os
 import skimage
 import IPython.display
 import matplotlib.pyplot as plt
 from PIL import Image
 import numpy as np
 
 from collections import OrderedDict
 import torch
 
 %matplotlib inline
 %config InlineBackend.figure_format = 'retina'
 
 # images in skimage to use and their textual descriptions
 descriptions = {
     "page": "a page of text about segmentation",
     "chelsea": "a facial photo of a tabby cat",
     "astronaut": "a portrait of an astronaut with the American flag",
     "rocket": "a rocket standing on a launchpad",
     "motorcycle_right": "a red motorcycle standing in a garage",
     "camera": "a person looking at a camera on a tripod",
     "horse": "a black-and-white silhouette of a horse", 
     "coffee": "a cup of coffee on a saucer"
 }original_images = []
 images = []
 texts = []
 plt.figure(figsize=(16, 5))
 
 for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
     name = os.path.splitext(filename)[0]
     if name not in descriptions:
         continue
 
     image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
   
     plt.subplot(2, 4, len(images) + 1)
     plt.imshow(image)
     plt.title(f"{filename}\n{descriptions[name]}")
     plt.xticks([])
     plt.yticks([])
 
     original_images.append(image)
     images.append(preprocess(image))
     texts.append(descriptions[name])
 
 plt.tight_layout()

后果的可视化如下:

咱们对图像进行规范化,对每个文本输出进行标记,并运行模型的正流传取得图像和文本的特色。

 image_input = torch.tensor(np.stack(images)).cuda()
 text_tokens = clip.tokenize(["This is" + desc for desc in texts]).cuda()
 
 with torch.no_grad():
     image_features = model.encode_image(image_input).float()
     text_features = model.encode_text(text_tokens).float()

咱们将特色归一化,并计算每一对的点积,进行余弦类似度计算

 image_features /= image_features.norm(dim=-1, keepdim=True)
 text_features /= text_features.norm(dim=-1, keepdim=True)
 similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
 
 count = len(descriptions)
 
 plt.figure(figsize=(20, 14))
 plt.imshow(similarity, vmin=0.1, vmax=0.3)
 # plt.colorbar()
 plt.yticks(range(count), texts, fontsize=18)
 plt.xticks([])
 for i, image in enumerate(original_images):
     plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
 for x in range(similarity.shape[1]):
     for y in range(similarity.shape[0]):
         plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
         
 for side in ["left", "top", "right", "bottom"]:
   plt.gca().spines[side].set_visible(False)
 
 plt.xlim([-0.5, count - 0.5])
 plt.ylim([count + 0.5, -2])
 
 plt.title("Cosine similarity between text and image features", size=20)

零样本的图像分类

 from torchvision.datasets import CIFAR100
 cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
 text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
 text_tokens = clip.tokenize(text_descriptions).cuda()
 with torch.no_grad():
     text_features = model.encode_text(text_tokens).float()
     text_features /= text_features.norm(dim=-1, keepdim=True)
     
 text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
 top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
 plt.figure(figsize=(16, 16))
 for i, image in enumerate(original_images):
     plt.subplot(4, 4, 2 * i + 1)
     plt.imshow(image)
     plt.axis("off")
 
     plt.subplot(4, 4, 2 * i + 2)
     y = np.arange(top_probs.shape[-1])
     plt.grid()
     plt.barh(y, top_probs[i])
     plt.gca().invert_yaxis()
     plt.gca().set_axisbelow(True)
     plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
     plt.xlabel("probability")
     
 plt.subplots_adjust(wspace=0.5)
 plt.show()

能够看到,分类的成果还是十分好的

https://avoid.overfit.cn/post/91d2e9fa40ca40208b9c7112cd825d3b

作者:Aryan Jadon

正文完
 0