关于人工智能:深度学习应用篇元学习15基于度量的元学习SNAILRNPNMN

31次阅读

共计 14002 个字符,预计需要花费 36 分钟才能阅读完成。

深度学习利用篇 - 元学习[15]:基于度量的元学习:SNAIL、RN、PN、MN

1.Simple Neural Attentive Learner(SNAIL)

元学习能够被定义为一种序列到序列的问题,
在现存的办法中,元学习器的瓶颈是如何去排汇异化利用过来的教训。
注意力机制能够容许在历史中精准摘取某段具体的信息。

Simple Neural Attentive Learner (SNAIL)
组合时序卷积和 soft-attention,
前者从过来的教训整合信息,后者准确查找到某些非凡的信息。

1.1 Preliminaries

1.1.1 时序卷积和 soft-attention

时序卷积 (TCN) 是有因果前后关系的,即在下一时间步生成的值仅仅受之前的工夫步影响。
TCN 能够提供更间接,高带宽的传递信息的办法,这容许它们基于一个固定大小的时序内容进行更简单的计算。
然而,随着序列长度的减少,卷积收缩的尺度会随之指数减少,须要的层数也会随之对数减少。
因而这种办法对于之前输出的拜访更粗略,且他们的无限的能力和地位依赖并不适宜元学习器,
因为元学习器应该可能利用增长数量的教训,而不是随着教训的减少,性能会被受限。

soft-attention 能够实现从超长的序列内容中获取精确的非凡信息。
它将上下文作为一种无序的要害值存储,这样就能够基于每个元素的内容进行查问。
然而,地位依赖的不足(因为是无序的)也是一个毛病。

TCN 和 soft-attention 能够实现性能互补:
前者提供高带宽的办法,代价是受限于上下文的大小,后者能够基于不确定的可能无限大的上下文提供精准的提取。
因而,SNAIL 的构建应用二者的组合:应用时序卷积去解决用注意力机制提取过的内容。
通过整合 TCN 和 attention,SNAIL 能够基于它过来的教训产出高带宽的解决办法且不再有教训数量的限度。
通过在多个阶段应用注意力机制,端到端训练的 SNAIL 能够学习从收集到的信息中如何摘取本人须要的信息并学习一个失当的示意。

1.1.2 Meta-Learning

在元学习中每个工作 $\mathcal{T}_{i}$ 都是独立的,
其输出为 $x_{t}$,输入为 $a_{t}$,损失函数是 $\mathcal{L}_{i}\left(x_{t}, a_{t}\right)$,
一个转移散布 $P_{i}\left(x_{t} \mid x_{t-1}, a_{t-1}\right)$,和一个输入长度 $H_i$。
一个元学习器(由 $\theta$ 参数化)建模散布:

$$
\pi\left(a_{t} \mid x_{1}, \ldots, x_{t} ; \theta\right)
$$

给定一个工作的散布 $\mathcal{T}=P\left(\mathcal{T}_{i}\right)$,
元学习器的指标是最小化它的期待损失:

$$
\begin{aligned}
&\min _{\theta} \mathbb{E}_{\mathcal{T}_{i} \sim \mathcal{T}}\left[\sum_{t=0}^{H_{i}} \mathcal{L}_{i}\left(x_{t}, a_{t}\right)\right] \\
&\text {where} x_{t} \sim P_{i}\left(x_{t} \mid x_{t-1}, a_{t-1}\right), a_{t} \sim \pi\left(a_{t} \mid x_{1}, \ldots, x_{t} ; \theta\right)
\end{aligned}
$$

元学习器被训练去针对从 $\mathcal{T}$ 中抽样进去的工作 (或一个 mini-batches 的工作) 优化这个冀望损失。
在测试阶段,元学习器在新工作散布 $\widetilde{\mathcal{T}}=P\left(\widetilde{\mathcal{T}}_{i}\right)$ 上被评估。

1.2 SNAIL

1.2.1 SNAIL 根底构造

两个时序卷积层(橙色)和一个因果关系层(绿色)的组合是 SNAIL 的根底构造,
如图 1 所示。
在监督学习设置中,
SNAIL 接管标注样本 $\left(x_{1}, y_{1}\right), \ldots,\left(x_{t-1}, y_{t-1}\right)$ 和末标注的 $\left(x_{t},-\right)$,
而后基于标注样本对 $y_{t}$ 进行预测。

图 1 SNAIL 根底构造示意图。

1.2.2 Modular Building Blocks

对于构建 SNAIL 应用了两个次要模块:
Dense Block 和 Attention Block。

图 1 SNAIL 中的 Dense Block 和 Attention Block。(a) Dense Block 利用因果一维卷积,而后将输入连贯到输出。TC Block 利用一系列膨胀率呈指数增长的 Dense Block。(b) Attention Block 执行 (因果) 键值查找,并将输入连贯到输出。

Densen Block
用了一个简略的因果一维卷积(空洞卷积),
其中膨胀率 (dilation)为 $R$ 和卷积核数量 $D$([1] 对于所有的试验中设置卷积核的大小为 2),
最初合并后果和输出。
在计算结果的时候应用了一个门激活函数。
具体算法如下:

  1. function DENSENBLOCK (inuts, dilation rate $R$, number of filers $D$):

    1. xf, xg = CausalConv (inputs, $R$, $D$), CausalConv (inputs, $R$, $D$)
    2. activations = tanh (xf) * sigmoid (xg)
    3. return concat (inputs, activations)

TC Block
由一系列 dense block 组成,这些 dense block 的膨胀率 $R$ 呈指数级增长,直到它们的承受域超过所需的序列长度。具体代码实现时,对序列是须要填充的为了放弃序列长度不变。具体算法如下:

  1. function TCBLOCK (inuts, sequence length $T$, number of filers $D$):

    1. for i in $1, \ldots, \left[log_2T\right]$ do

       1. inputs = DenseBlock (inputs, $2^i$, $D$)
    2. return inputs

Attention Block
[1] 中设计成 soft-attention 机制,
公式为:

$$
\mathrm{Attention}(Q, K, V)=\mathrm{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V
$$

  1. function ATTENTIONBLOCK (inuts, key size $K$, value size $V$):

    1. keys, query = affine (inputs, $K$), affine (inputs, $K$)
    2. logits = matmul (query, transpose (keys))
    3. probs = CausallyMaskedSoftmax ($\mathrm{logits} / \sqrt{K}$)
    4. values = affine (inputs, $V$)
    5. read = matmul (probs, values)
    6. return concat (inputs, read)

1.3 SNAIL 分类后果

<center>
表 1 SNAIL 在 Omniglot 上的分类后果。
</center>

Method 5-way 1-shot 5-way 5-shot 20-way 1-shot 20-way 5-shot
Santoro et al. (2016) 82.8 $\%$ 94.9 $\%$
Koch (2015) 97.3 $\%$ 98.4 $\%$ 88.2 $\%$ 97.0 $\%$
Vinyals et al. (2016) 98.1 $\%$ 98.9 $\%$ 93.8 $\%$ 98.5 $\%$
Finn et al. (2017) 98.7 $\pm$ 0.4 $\%$ 99.9 $\pm$ 0.3 $\%$ 95.8 $\pm$ 0.3 $\%$ 98.9 $\pm$ 0.2 $\%$
Snell et al. (2017) 97.4 $\%$ 99.3 $\%$ 96.0 $\%$ 98.9 $\%$
Munkhdalai $\&$ Yu (2017) 98.9 $\%$ 97.0 $\%$
SNAIL 99.07 $\pm$ 0.16 $\%$ 99.78 $\pm$ 0.09 $\%$ 97.64 $\pm$ 0.30 $\%$ 99.36 $\pm$ 0.18 $\%$

<center>
表 1 SNAIL 在 miniImageNet 上的分类后果。
</center>

Method 5-way 1-shot 5-way 5-shot
Vinyals et al. (2016) 43.6 $\%$ 55.3 $\%$
Finn et al. (2017) 48.7 $\pm$ 1.84 $\%$ 63.1 $\pm$ 0.92 $\%$
Ravi $\&$ Larochelle (2017) 43.4 $\pm$ 0.77 $\%$ 60.2 $\pm$ 0.71 $\%$
Snell et al. (2017) 46.61 $\pm$ 0.78 $\%$ 65.77 $\pm$ 0.70 $\%$
Munkhdalai $\&$ Yu (2017) 49.21 $\pm$ 0.96 $\%$
SNAIL 55.71 $\pm$ 0.99 $\%$ 68.88 $\pm$ 0.92 $\%$
  • 参考文献

[1] A Simple Neural Attentive Meta-Learner

2.Relation Network(RN)

Relation Network (RN) 应用有监督度量学习预计样本点之间的间隔,
依据新样本点和过来样本点之间的间隔远近,对新样本点进行分类。

2.1 RN

RN 包含两个组成部分:嵌入模块和关系模块,且两者都是通过有监督学习失去的。
嵌入模块从输出数据中提取特色,关系模块依据特色计算工作之间的间隔,
判断工作之间的相似性,找到过来可借鉴的教训进行加权均匀。
RN 构造如图 1 所示。

图 1 RN 构造。

嵌入模块记为 $f_{\varphi}$,关系模块记为 $g_{\phi}$,
反对集中的样本记为 $\boldsymbol{x}_{i}$,
查问集中的样本记为 $\boldsymbol{x}_{j}$。

  • 将 $\boldsymbol{x}_{i}$ 和 $\boldsymbol{x}_{j}$ 输出 $f_{\varphi}$,
    产生特色映射 $f_{\varphi}\left(\boldsymbol{x}_{i}\right)$
    和 $f_{\varphi}\left(\boldsymbol{x}_{j}\right)$。
  • 通过运算器 $C(.,.)$ 将 $f_{\varphi}\left(\boldsymbol{x}_{i}\right)$
    和 $f_{\varphi}\left(\boldsymbol{x}_{j}\right)$ 联合,
    失去 $C(f_{\varphi}\left(\boldsymbol{x}_{i}\right),f_{\varphi}\left(\boldsymbol{x}_{j}\right))$。
  • 将 $C(f_{\varphi}\left(\boldsymbol{x}_{i}\right),f_{\varphi}\left(\boldsymbol{x}_{j}\right))$ 输出 $g_{\phi}$,
    失去 $[0, 1]$ 范畴内的标量,
    示意 $\boldsymbol{x}_{i}$ 和 $\boldsymbol{x}_{j}$ 之间的相似性,记为关系得分 $r_{i, j}$。
    $\boldsymbol{x}_{i}$ 和 $\boldsymbol{x}_{j}$ 类似度越高,$r_{i, j}$ 越大。

$$
r_{i, j}=g_{\phi}\left(C\left(f_{\varphi}\left(\boldsymbol{x}_{i}\right), f_{\varphi}\left(\boldsymbol{x}_{j}\right)\right)\right), \
i = 1, 2, …, C
$$

2.2 RN 指标函数

$$
\phi, \varphi \leftarrow \underset{\phi, \varphi}{\arg \min} \sum_{i=1}^{m} \sum_{j=1}^{n}\left(r_{i, j}-1\left(\boldsymbol{y}_{i}==\boldsymbol{y}_{j}\right)\right)^{2}
$$

其中,$1\left(\boldsymbol{y}_{i}=\boldsymbol{y}_{j}\right)$ 用来判断 $\boldsymbol{x}_{i}$ 和 $\boldsymbol{x}_{j}$ 是否属于同一类别。
当 $\boldsymbol{y}_{i}=\boldsymbol{y}_{j}$ 时,$1\left(\boldsymbol{y}_{i}==\boldsymbol{y}_{j}\right)=1$,
当 $\boldsymbol{y}_{i} \neq \boldsymbol{y}_{j}$ 时,$1\left(\boldsymbol{y}_{i}==\boldsymbol{y}_{j}\right)=0$。

2.3 RN 网络结构

嵌入模块和关系模块的选取有很多种,包含卷积网络、残差网络等。

图 2 给出了 [1] 中应用的 RN 模型构造。

图 2 RN 模型构造。

2.3.1 嵌入模块构造

  • 每个卷积块别离蕴含 64 个 3 $\times$ 3 滤波器进行卷积,一个归一化层、一个 ReLU 非线性层。
  • 总共有四个卷积块,前两个卷积块蕴含 2 $\times$ 2 的最大池化层,后边两个卷积块没有池化层。

3.2 关系模块构造

  • 有两个卷积块,每个卷积模块中都蕴含 2 $\times$ 2 的最大池化层。
  • 两个全连贯层,第一个全连贯层是 ReLU 非线性变换,最初的全连贯层应用 Sigmoid 非线性变换输入 $r_{i,j}$。

2.4 RN 分类后果

<center>
表 1 RN 在 Omniglot 上的分类后果。
</center>

Model Fine Tune 5-way 1-shot 5-way 5-shot 20-way 1-shot 20-way 5-shot
MANN N 82.8 $\%$ 94.9 $\%$
CONVOLUTIONAL SIAMESE NETS N 96.7 $\%$ 98.4 $\%$ 88.0 $\%$ 96.5 $\%$
CONVOLUTIONAL SIAMESE NETS Y 97.3 $\%$ 98.4 $\%$ 88.1 $\%$ 97.0 $\%$
MATCHING NETS N 98.1 $\%$ 98.9 $\%$ 93.8 $\%$ 98.5 $\%$
MATCHING NETS Y 97.9 $\%$ 98.7 $\%$ 93.5 $\%$ 98.7 $\%$
SIAMESE NETS WITH MEMORY N 98.4 $\%$ 99.6 $\%$ 95.0 $\%$ 98.6 $\%$
NEURAL STATISTICIAN N 98.1 $\%$ 99.5 $\%$ 93.2 $\%$ 98.1 $\%$
META NETS N 99.0 $\%$ 97.0 $\%$
PROTOTYPICAL NETS N 98.8 $\%$ 99.7 $\%$ 96.0 $\%$ 98.9 $\%$
MAML Y 98.7 $\pm$ 0.4 $\%$ 99.9 $\pm$ 0.1 $\%$ 95.8 $\pm$ 0.3 $\%$ 98.9 $\pm$ 0.2 $\%$
RELATION NET N 99.6 $\pm$ 0.2 $\%$ 99.8 $\pm$ 0.1 $\%$ 97.6 $\pm$ 0.2 $\%$ 99.1 $\pm$ 0.1 $\%$

<center>
表 1 RN 在 miniImageNet 上的分类后果。
</center>

Model FT 5-way 1-shot 5-way 5-shot
MATCHING NETS N 43.56 $\pm$ 0.84 $\%$ 55.31 $\pm$ 0.73 $\%$
META NETS N 49.21 $\pm$ 0.96 $\%$
META-LEARN LSTM N 43.44 $\pm$ 0.77 $\%$ 60.60 $\pm$ 0.71 $\%$
MAML Y 48.70 $\pm$ 1.84 $\%$ 63.11 $\pm$ 0.92 $\%$
PROTOTYPICAL NETS N 49.42 $\pm$ 0.78 $\%$ 68.20 $\pm$ 0.66 $\%$
RELATION NET N 50.44 $\pm$ 0.82 $\%$ 65.32 $\pm$ 0.70 $\%$
  • 参考文献

[1] Learning to Compare: Relation Network for Few-Shot Learning

3.Prototypical Network(PN)

Prototypical Network (PN) 利用反对集中每个类别提供的大量样本,
计算它们的嵌入核心,作为每一类样本的原型 (Prototype),
接着基于这些原型学习一个度量空间,
使得新的样本通过计算本身嵌入与这些原型的间隔实现最终的分类。

3.1 PN

在 few-shot 分类工作中,
假如有 $N$ 个标记的样本 $S=\left(x_{1}, y_{1}\right), \ldots,\left(x_{N}, y_{N}\right)$,
其中,$x_{i} \in$ $\mathbb{R}^{D}$ 是 $D$ 维的样本特征向量,
$y \in 1, \ldots, K$ 是相应的标签。
$S_{K}$ 示意第 $k$ 类样本的汇合。

PN 计算每个类的 $M$ 维原型向量 $c_{k} \in \mathbb{R}^{M}$,
计算的函数为 $f_{\phi}: \mathbb{R}^{D} \rightarrow \mathbb{R}^{M}$,
其中 $\phi$ 为可学习参数。
原型向量 $c_{k}$ 即为嵌入空间中该类的所有 反对集样本点的均值向量

$$
c_{k}=\frac{1}{\left|S_{K}\right|} \sum_{\left(x_{i}, y_{i}\right) \in S_{K}} f_{\phi}\left(x_{i}\right)
$$

给定一个间隔函数 $d: \mathbb{R}^{M} \times \mathbb{R}^{M} \rightarrow[0,+\infty)$,
不蕴含任何可训练的参数,
PN 通过在嵌入空间中对间隔进行 softmax 计算,
失去一个针对 $x$ 的样本点的概率分布

$$
p_{\phi}(y=k \mid x)=\frac{\exp \left(-d\left(f_{\phi}(x), c_{k}\right)\right)}{\sum_{k^{\prime}} \exp \left(-d\left(f_{\phi}(x), c_{k^{\prime}}\right)\right)}
$$

新样本点的特色离类别中心点越近,
新样本点属于这个类别的概率越高;
新样本点的特色离类别中心点越远,
新样本点属于这个类别的概率越低。

通过在 SGD 中最小化第 $k$ 类的负对数似然函数 $J(\phi)$ 来推动学习

$$
J(\phi)= \underset{\phi}{\operatorname{argmin}}\left(\sum_{k=1}^{K}-\log \left(p_{\phi}\left(\boldsymbol{y}=k \mid \boldsymbol{x}_{k}\right)\right)\right)
$$

PN 示意图如图 1 所示。

图 1 PN 示意图。

3.2 PN 算法流程

Input: Training set $\mathcal{D}=\left\{\left(\mathbf{x}_{1}, y_{1}\right), \ldots,\left(\mathbf{x}_{N}, y_{N}\right)\right\}$, where each $y_{i} \in\{1, \ldots, K\}$. $\mathcal{D}_{k}$ denotes the subset of $\mathcal{D}$ containing all elements $\left(\mathbf{x}_{i}, y_{i}\right)$ such that $y_{i}=k$.

Output: The loss $J$ for a randomly generated training episode.

  1. select class indices for episode: $V \leftarrow \text {RANDOMSAMPLE}\left(\{1, \ldots, K\}, N_{C}\right)$
  2. for $k$ in $\left\{1, \ldots, N_{C}\right\}$ do

    1. select support examples: $S_{k} \leftarrow \text {RANDOMSAMPLE}\left(\mathcal{D}_{V_{k}}, N_{S}\right)$
    2. select query examples: $Q_{k} \leftarrow \text {RANDOMSAMPLE}\left(\mathcal{D}_{V_{k}} \backslash S_{k}, N_{Q}\right)$
    3. compute prototype from support examples: $c_k \leftarrow \frac{1}{N_{C}} \sum_{\left(\mathbf{x}_{i}, y_{i}\right) \in S_{k}} f_{\phi}\left(\mathbf{x}_{i}\right)$
  3. end for
  4. $J \leftarrow 0$
  5. for $k$ in $\left\{1, \ldots, N_{C}\right\}$ do

    1. for $x, y$ in $Q_{k}$ do
    2. update loss $\left.J \leftarrow J+\frac{1}{N_{C} N_{Q}}\left[d\left(f_{\phi}(\mathbf{x}), \mathbf{c}_{k}\right)\right)+\log \sum_{k^{\prime}} \exp \left(-d\left(f_{\phi}(\mathbf{x}), \mathbf{c}_{k^{\prime}}\right)\right)\right]$
  6. end for
  7. end for

其中,

  • $N$ 是训练集中的样本个数;
  • $K$ 是训练集中的类个数;
  • $N_{C} \leq K$ 是每个 episode 选出的类个数;
  • $N_{S}$ 是每类中 support set 的样本个数;
  • $N_{Q}$ 是每类中 query set 的样本个数;
  • $\mathrm{RANDOMSAMPLE}(S, N)$ 示意从汇合 $\mathrm{S}$ 中随机选出 $\mathrm{N}$ 个元素。

3.3 PN 分类后果

<center>
表 1 PN 在 Omniglot 上的分类后果。
</center>

Model Dist. Fine Tune 5-way 1-shot 5-way 5-shot 20-way 1-shot 20-way 5-shot
MATCHING NETWORKS Cosine N 98.1 $\%$ 98.9 $\%$ 93.8 $\%$ 98.5 $\%$
MATCHING NETWORKS Cosine Y 97.9 $\%$ 98.7 $\%$ 93.5 $\%$ 98.7 $\%$
NEURAL STATISTICIAN N 98.1 $\%$ 99.5 $\%$ 93.2 $\%$ 98.1 $\%$
MAML N 98.7 $\%$ 99.9 $\%$ 95.8 $\%$ 98.9 $\%$
PROTOTYPICAL NETWORKS Euclid. N 98.8 $\%$ 99.7 $\%$ 96.0 $\%$ 98.9 $\%$

<center>
表 1 PN 在 miniImageNet 上的分类后果。
</center>

Model Dist. Fine Tune 5-way 1-shot 5-way 5-shot
BASELINE NEAREST NEIGHBORS Cosine N 28.86 $\pm$ 0.54 $\%$ 49.79 $\pm$ 0.79 $\%$
MATCHING NETWORKS Cosine N 43.40 $\pm$ 0.78 $\%$ 51.09 $\pm$ 0.71 $\%$
MATCHING NETWORKS (FCE) Cosine N 43.56 $\pm$ 0.84 $\%$ 55.31 $\pm$ 0.73 $\%$
META-LEARNER LSTM N 43.44 $\pm$ 0.77 $\%$ 60.60 $\pm$ 0.71 $\%$
MAML N 48.70 $\pm$ 1.84 $\%$ 63.15 $\pm$ 0.91 $\%$
PROTOTYPICAL NETWORKS Euclid. N 49.42 $\pm$ 0.78 $\%$ 68.20 $\pm$ 0.66 $\%$
  • 参考文献

[1] Prototypical Networks for Few-shot Learning

4.Matching Network(MN)

Matching Network (MN)
联合了度量学习 (Metric Learning) 与记忆加强神经网络 (Memory Augment Neural Networks),
并利用注意力机制与记忆机制减速学习,同时提出了 set-to-set 框架,
使得 MN 可能为新类产生正当的测试标签,且不必网络做任何扭转。

4.1 MN

将反对集 $S=\left\{\left(x_{i}, y_{i}\right)\right\}_{i=1}^{k}$
映射到一个分类器 $c_{S}(\hat{x})$,
给定一个测试样本 $\hat{x}$,$c_{S}(\hat{x})$ 定义一个对于输入 $\hat{y}$ 的概率分布,即

$$
S \rightarrow c_{S}\left(\hat{x}\right):=
P\left(\hat{y} \mid \hat{x}, S\right)
$$

其中,$P$ 被网络参数化。
因而,当给定一个新的反对集 $S^{\prime}$ 进行小样本学习时,
只需应用 $P$ 定义的网络来预测每个测试示例 $\hat{x}$ 的适当标签散布
$P\left(\hat{y} \mid \hat{x}, S^{\prime}\right)$ 即可。

4.1.1 注意力机制

模型以最简略的模式计算 $\hat{y}$ 上的概率:

$$
P(\hat{y} \mid \hat{x}, S)=\sum_{i=1}^{k} a\left(\hat{x}, x_{i}\right) y_{i}
$$

上式实质是将一个输出的新类形容为反对集中所有类的一个线性组合,
联合了核密度估计 KDE($a$ 能够看做是一种核密度估计)和 KNN。
其中, $k$ 示意反对集中样本类别数,
$a\left(\hat{x}, x_{i}\right)$ 是注意力机制,
相似 attention 模型中的核函数,
用来度量 $\hat{x}$ 和训练样本 $x_{i}$ 的匹配度。

$a$ 的计算基于新样本数据与反对集中的样本数据的嵌入示意的余弦类似度以及 softmax 函数:

$$
a\left(\hat{x}, x_{i}\right)=\frac{e^{c\left(f(\hat{x}), g\left(x_{i}\right)\right)}}{\sum_{j=1}^{k} e^{c\left(f(\hat{x}), g\left(x_{j}\right)\right)}}
$$

其中,$c(\cdot)$ 示意余弦类似度,
$f$ 与 $g$ 示意施加在测试样本与训练样本上的嵌入函数 (Embedding Function)。

如果注意力机制是 $X \times X$ 上的核,
则上式相似于核密度估计器。
如果选取适合的间隔度量以及适当的常数,
从而使得从 $x_{i}$ 到 $\hat{x}$ 的注意力机制为 0,
则上式等价于 KNN。

图 1 是 MN 的网络结构示意图。

图 1 MN 示意图。

4.1.2 Full Context Embeddings

为了加强样本嵌入的匹配度,
[1] 提出了 Full Context Embeeding (FCE) 办法:
反对集中每个样本的嵌入应该是互相独立的,
而新样本的嵌入应该受反对集样本数据分布的调控,
其嵌入过程须要放在整个反对集环境下进行,
因而 [1] 采纳带有注意力的 LSTM 网络对新样本进行嵌入。

在对余弦注意力定义时,
每个已知标签的输出 $x_i$ 通过 CNN 后的 embedding,
因而 $g(x_i)$ 是独立的,前后没有关系,
而后与 $f\left(\hat{x}\right)$ 进行一一比照,
并没有思考到输出工作 $S$ 扭转 embedding $\hat{x}$ 的形式,
而 $f(\cdot)$ 应该是受 $g(S)$ 影响的。
为了实现这个性能,[1] 采纳了双向 LSTM。

在通过嵌入函数 $f$ 和 $g$ 解决后,
输入再次通过循环神经网络进一步增强 context 和个体之间的关系。

$$
f\left(\hat{x},S\right)=\mathrm{attLSTM}\left(f’\left(\hat{x}\right),g(S),K\right)
$$

其中,$S$ 是相干的上下文,$K$ 为网络的 timesteps。

因而,通过 $k$ 步后的状态为:

$$
\begin{aligned}
& \hat{h}_{k}, c_{k} =\operatorname{LSTM}\left(f^{\prime}(\hat{x}),\left[h_{k-1}, r_{k-1}\right], c_{k-1}\right) \\
& h_{k} =\hat{h}_{k}+f^{\prime}(\hat{x}) \\
& r_{k-1} =\sum_{i=1}^{|S|} a\left(h_{k-1}, g\left(x_{i}\right)\right) g\left(x_{i}\right) \\
& a\left(h_{k-1}, g\left(x_{i}\right)\right) =e^{h_{k-1}^{T} g\left(x_{i}\right)} / \sum_{j=1}^{|S|} e^{h_{k-1}^{T} g\left(x_{j}\right)}
\end{aligned}
$$

4.2 网络结构

特征提取器可采纳常见的 VGG 或 Inception 网络,
[1] 设计了一种简略的四级网络结构用于图像分类工作的特征提取,
每级网络由一个 64 通道的 3 $\times$ 3 卷积层,一个批规范化层,
一个 ReLU 激活层和一个 2 $\times$ 2 的最大池化层形成。
而后将最初一层输入的特色输出到 LSTM 网络中失去最终的特色映射
$f\left(\hat{x},S\right)$ 和 $g\left({x_i},S\right)$。

4.3 损失函数

$$
\theta=\arg \max _{\theta} E_{L \sim T}\left[E_{S \sim L, B \sim L}\left[\sum_{(x, y) \in B} \log P_{\theta}(y \mid x, S)\right]\right]
$$

4.4 MN 算法流程

  • 将工作 $S$ 中所有图片 $x_i$(假如有 $K$ 个)和指标图片 $\hat{x}$(假如有 1 个)
    全副通过 CNN 网络,取得它们的浅层变量示意。
  • 将($K+1$ 个)浅层变量全副输出到 BiLSTM 中,取得 $K+1$ 个输入,
    而后应用余弦间隔判断前 $K$ 个输入中每个输入与最初一个输入之间的类似度。
  • 依据计算出来的类似度,依照工作 $S$ 中的标签信息 $y_1, y_2, \ldots, y_K$
    求解指标图片 $\hat{x}$ 的类别标签 $\hat{y}$。

4.5 MN 分类后果

<center>
表 1 MN 在 Omniglot 上的分类后果。
</center>

Model Matching Fn Fine Tune 5-way 1-shot 5-way 5-shot 20-way 1-shot 20-way 5-shot
PIXELS Cosine N 41.7 $\%$ 63.2 $\%$ 26.7 $\%$ 42.6 $\%$
BASELINE CLASSIFIER Cosine N 80.0 $\%$ 95.0 $\%$ 69.5 $\%$ 89.1 $\%$
BASELINE CLASSIFIER Cosine Y 82.3 $\%$ 98.4 $\%$ 70.6 $\%$ 92.0 $\%$
BASELINE CLASSIFIER Softmax Y 86.0 $\%$ 97.6 $\%$ 72.9 $\%$ 92.3 $\%$
MANN (NO CNOV) Cosine N 82.8 $\%$ 94.9 $\%$
CONVOLUTIONAL SIAMESE NET Cosine Y 96.7 $\%$ 98.4 $\%$ 88.0 $\%$ 96.5 $\%$
CONVOLUTIONAL SIAMESE NET Cosine Y 97.3 $\%$ 98.4 $\%$ 88.1 $\%$ 97.0 $\%$
MATCHING NETS Cosine N 98.1 $\%$ 98.9 $\%$ 93.8 $\%$ 98.5 $\%$
MATCHING NETS Cosine Y 97.9 $\%$ 98.7 $\%$ 93.5 $\%$ 98.7 $\%$

<center>
表 1 MN 在 miniImageNet 上的分类后果。
</center>

Model Matching Fn Fine Tune 5-way 1-shot 5-way 5-shot
PIXELS Cosine N 23.0 $\%$ 26.6 $\%$
BASELINE CLASSIFIER Cosine N 36.6 $\%$ 46.0 $\%$
BASELINE CLASSIFIER Cosine Y 36.2 $\%$ 52.2 $\%$
BASELINE CLASSIFIER Cosine Y 38.4 $\%$ 51.2 $\%$
MATCHING NETS Cosine N 41.2 $\%$ 56.2 $\%$
MATCHING NETS Cosine Y 42.4 $\%$ 58.0 $\%$
MATCHING NETS Cosine (FCE) N 44.2 $\%$ 57.0 $\%$
MATCHING NETS Cosine (FCE) Y 46.6 $\%$ 60.0 $\%$

4.6 翻新点

  • 采纳匹配的模式实现小样本分类工作,
    引入最近邻算法的思维解决了深度学习算法在小样本的条件下无奈充沛优化参数而导致的过拟合问题,
    且利用带有注意力机制和记忆模块的网络解决了一般最近邻算法适度依赖度量函数的问题,
    将样本的特色信息映射到更高维度更形象的特色空间中。
  • one-shot learning 的训练策略,一个训练任务中蕴含反对集和 Batch 样本。

4.7 算法评估

  • MN 受到非参量化算法的限度,
    随着反对集 $S$ 的增长,每次迭代的计算量也会随之快速增长,导致计算速度升高。
  • 在测试时必须提供蕴含指标样本类别在内的反对集,
    否则它只能从反对集所蕴含的类别中抉择最为靠近的一个输入其类别,而不能输入正确的类别。
  • 参考文献

[1] Matching Networks for One Shot Learning

更多优质内容请关注公号:汀丶人工智能

正文完
 0