深度学习利用篇 - 元学习[14]:基于优化的元学习 -MAML 模型、LEO 模型、Reptile 模型
1.Model-Agnostic Meta-Learning
Model-Agnostic Meta-Learning (MAML):
与模型无关的元学习,可兼容于任何一种采纳梯度降落算法的模型。
MAML 通过大量的数据寻找一个适合的初始值范畴,从而扭转梯度降落的方向,
找到对工作更加敏感的初始参数,
使得模型可能在无限的数据集上疾速拟合,并取得一个不错的成果。
该办法能够用于回归、分类以及强化学习。
该模型的 Paddle 实现请参考链接:PaddleRec 版本
1.1 MAML
MAML 是典型的双层优化结构,其内层和外层的优化形式如下:
1.1.1 MAML 内层优化形式
内层优化波及到基学习器,从工作散布 $p(T)$ 中随机采样第 $i$ 个工作 $T_{i}$。工作 $T_{i}$ 上,基学习器的指标函数是:
$$
\min _{\phi} L_{T_{i}}\left(f_{\phi}\right)
$$
其中,$f_{\phi}$ 是基学习器,$\phi$ 是基学习器参数,$L_{T_{i}}\left(f_{\phi}\right)$ 是基学习器在 $T_{i}$ 上的损失。更新基学习器参数:
$$
\theta_{i}^{N}=\theta_{i}^{N-1}-\alpha\left[\nabla_{\phi}
L_{T_{i}}\left(f_{\phi}\right)\right]_{\phi=\theta_{i}^{N-1}}
$$
其中,$\theta$ 是元学习器提供给基学习器的参数初始值 $\phi=\theta$,在工作 $T_{i}$ 上更新 $N$ 后 $\phi=\theta_{i}^{N-1}$.
1.1.2 MAML 外层优化形式
外层优化波及到元学习器,将 $\theta_{i}^{N}$ 反馈给元学匀器,此时元指标函数是:
$$
\min _{\theta} \sum_{T_{i}\sim p(T)} L_{T_{i}}\left(f_{\theta_{i}^{N}}\right)
$$
元指标函数是所有工作上验证集损失和。更新元学习器参数:
$$
\theta \leftarrow \theta-\beta \sum_{T_{i} \sim p(T)} \nabla_{\theta}\left[L_{T_{i}}\left(f_{\phi}\right)\right]_{\phi=\theta_{i}^{N}}
$$
1.2 MAML 算法流程
- randomly initialize $\theta$
while not done do:
- sample batch of tasks $T_i \sim p(T)$
for all $T_i$ do:
- evaluate $\nabla_{\phi}L_{T_{i}}\left(f_{\phi}\right)$ with respect to K examples
- compute adapted parameters with gradient descent: $\theta_{i}^{N}=\theta_{i}^{N-1} -\alpha\left[\nabla_{\phi}L_{T_{i}}\left(f_{\phi}\right)\right]_{\phi=\theta_{i}^{N-1}} $
- end for
- update $\theta \leftarrow \theta-\beta \sum_{T_{i} \sim p(T)} \nabla_{\theta}\left[L_{T_{i}}\left(f_{\phi}\right)\right]_{\phi=\theta_{i}^{N}} $
- end while
MAML 中执行了两次梯度降落 (gradient by gradient),别离作用在基学习器和元学习器上。图 1 给出了 MAML 中特定工作参数 $\theta_{i}^{*}$ 和元级参数 $\theta$ 的更新过程。
<center>
图 1 MAML 示意图。灰色线示意特定工作所产生的梯度值(方向);彩色线示意元级参数抉择更新的方向(彩色线方向是几个特定工作产生方向的平均值);虚线代表疾速适应,不同的方向代表不同工作更新的方向。
</center>
1.3 MAML 模型构造
MAML 是一种与模型无关的元学习办法,能够实用于任何基于梯度优化的模型构造。
基准模型:4 modules with a 3 $\times$ 3 convolutions and 64 filters,
followed by batch normalization,
a ReLU nonlinearity,
and 2 $\times$ 2 max-pooling。
1.4 MAML 分类后果
<center>
表 1 MAML 在 Omniglot 上的分类后果。
</center>
Method | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
---|---|---|---|---|
MANN, no conv (Santoro et al., 2016) | 82.8 $\%$ | 94.9 $\%$ | — | — |
MAML, no conv | 89.7 $\pm$ 1.1 $\%$ | 97.5 $\pm$ 0.6 $\%$ | — | — |
Siamese nets (Koch, 2015) | 97.3 $\%$ | 98.4 $\%$ | 88.2 $\%$ | 97.0 $\%$ |
matching nets (Vinyals et al., 2016) | 98.1 $\%$ | 98.9 $\%$ | 93.8 $\%$ | 98.5 $\%$ |
neural statistician (Edwards & Storkey, 2017) | 98.1 $\%$ | 99.5 $\%$ | 93.2 $\%$ | 98.1 $\%$ |
memory mod. (Kaiser et al., 2017) | 98.4 $\%$ | 99.6 $\%$ | 95.0 $\%$ | 98.6 $\%$ |
MAML | 98.7 $\pm$ 0.4 $\%$ | 99.9 $\pm$ 0.1 $\%$ | 95.8 $\pm$ 0.3 $\%$ | 98.9 $\pm$ 0.2 $\%$ |
<center>
表 1 MAML 在 miniImageNet 上的分类后果。
</center>
Method | 5-way 1-shot | 5-way 5-shot |
---|---|---|
fine-tuning baseline | 28.86 $\pm$ 0.54 $\%$ | 49.79 $\pm$ 0.79 $\%$ |
nearest neighbor baseline | 41.08 $\pm$ 0.70 $\%$ | 51.04 $\pm$ 0.65 $\%$ |
matching nets (Vinyals et al., 2016) | 43.56 $\pm$ 0.84 $\%$ | 55.31 $\pm$ 0.73 $\%$ |
meta-learner LSTM (Ravi & Larochelle, 2017) | 43.44 $\pm$ 0.77 $\%$ | 60.60 $\pm$ 0.71 $\%$ |
MAML, first order approx. | 48.07 $\pm$ 1.75 $\%$ | 63.15 $\pm$ 0.91 $\%$ |
MAML | 48.70 $\pm$ 1.84 $\%$ | 63.11 $\pm$ 0.92 $\%$ |
1.5 MAML 的优缺点
长处
- 实用于任何基于梯度优化的模型构造。
- 双层优化结构,晋升模型精度和泛化能力,防止过拟合。
毛病
- 存在二阶导数计算
1.6 对 MAML 的探讨
- 每个工作上的基学习器必须是一样的,对于差异很大的工作,最切合工作的基学习器可能会变动,那么就不能用 MAML 来解决这类问题。
- MAML 实用于所有基于随机梯度算法求解的基学习器,这意味着参数都是间断的,无奈思考离散的参数。对于差异较大的工作,往往须要更新网络结构。应用 MAML 无奈实现这样的构造更新。
- MAML 应用的损失函数都是可求导的,这样能力应用随机梯度算法来疾速优化求解,损失函数中不能有不可求导的奇怪点,否则会导致优化求解不稳固。
- MAML 中思考的新工作都是类似的工作,所以没有对工作进行分类,也没有计算工作之间的间隔度量。对每一类工作独自更新其参数初始值,每一类工作的参数初始值不同,这些在 MAML 中都没有思考。
- 参考文献
[1] Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.
2.Latent Embedding Optimization
Latent Embedding Optimization (LEO) 学习模型参数的低维潜在嵌入,并在这个低维潜在空间中执行基于优化的元学习,将基于梯度的自适应过程与模型参数的根底高维空间拆散。
2.1 LEO
在元学习器中,应用 SGD 最小化工作验证集损失函数,
使得模型的泛化能力最大化,计算元参数,元学习器将元参数输出根底学习器,
继而,根底学习器最小化工作训练集损失函数,疾速给出工作上的预测后果。
LEO 构造如图 1 所示。
图 1 LEO 结构图。$D^{\mathrm{tr}}$ 是工作 $\varepsilon$ 的 support set,
$D^{\mathrm{val}}$ 是工作 $\varepsilon$ 的 query set,
$z$ 是通过编码器计算的 $N$ 个类别的类别特色,$f_{\theta}$ 是基学习器,
$\theta$ 是基学习器参数,
$L^{\mathrm{tr}}=f_{\theta}\left(D^{\mathrm{tr}}\right)$, $L^{\mathrm{val}}=f_{\theta}\left(D^{\mathrm{val}}\right)$。
LEO 包含根底学习器和元学习器,还包含编码器和解码器。
在根底学习器中,编码器将高维输出数据映射成特征向量,
解码器将输出数据的特征向量映射成输出数据属于各个类别的概率值,
根底学习器应用元学习器提供的元参数进行参数更新,给出数据标注的预测后果。
元学习器为根底学习器的编码器和解码器提供元参数,
元参数包含特征提取模型的参数、编码器的参数、解码器的参数等,
通过最小化所有工作上的泛化误差,更新元参数。
2.2 根底学习器
编码器和解码器都在根底学习器中,用于计算输出数据属于每个类别的概率值,
进而对输出数据进行分类。
元学习器提供编码器和解码器中的参数,根底学习器疾速的应用编码器和解码器计算输出数据的分类。
工作训练实现后,根底学习器将每个类别数据的特征向量和工作 $\varepsilon$ 的根底学习器参数 $\boldsymbol{\theta}_{\varepsilon}$ 输出元学习器,
元学习器应用这些信息更新元参数。
2.2.1 编码器
编码器模型包含两个次要局部:编码器和关系网络。
编码器 $g_{\phi_{e}}$,其中 $\phi_{e}$ 是编码器的可训练参数,
其性能是将第 $n$ 个类别的输出数据映射成第 $n$ 个类别的特征向量。
关系网络 $g_{\phi_{r}}$,其中 $\phi_{r}$ 是关系网络的可训练参数,
其性能是计算特色之间的间隔。
第 $n$ 个类别的输出数据的特色记为 $z_{n}$。
对于输出数据,首先,应用编码器 $g_{\phi_{e}}$ 对属于第 $n$ 个类别的输出数据进行特征提取;
而后,应用关系网络 $g_{\phi_r}$ 计算特色之间的间隔,
综合思考训练集中所有样本点之间的间隔,计算这些间隔的平均值和离散水平;
第 $n$ 个类别输出数据的特色 $z_{n}$ 遵从高斯分布,
且高斯分布的冀望是这些间隔的平均值,高斯分布的方差是这些间隔的离散水平,
具体的计算公式如下:
$$
\begin{aligned}
&\mu_{n}^{e}, \sigma_{n}^{e}=\frac{1}{N K^{2}} \sum_{k_{n}=1}^{K} \sum_{m=1}^{N} \sum_{k_{m}=1}^{K} g_{\phi_{r}}\left[g_{\phi_{e}}\left(x_{n}^{k_{n}}\right), g_{\phi_{e}}\left(x_{m}^{k_{m}}\right)\right] \\
&z_{n} \sim q\left(z_{n} \mid D_{n}^{\mathrm{tr}}\right)=N\left\{\mu_{n}^{e}, \operatorname{diag}\left(\sigma_{n}^{e}\right)^{2}\right\}
\end{aligned}
$$
其中,$N$ 是类别总数,$K$ 是每个类别的图片总数,
${D}_{n}^{\mathrm{tr}}$ 是第 $n$ 个类别的训练数据集。
对于每个类别的输出数据,每个类别下有 $K$ 张图片,
计算这 $K$ 张图片和所有已知图片之间的间隔。
总共有 $N$ 个类别,通过编码器的计算,造成所有类别的特色,
记为 $z=\left(z_{1}, \cdots, z_{N}\right)$。
2.2.2 解码器
解码器 $g_{\phi_{d}}$,其中 $\phi_{d}$ 是解码器的可训练参数,
其性能是将每个类别输出数据的特征向量 $z_{n}$
映射成属于每个类别的概率值 $\boldsymbol{w}_{n}$:
$$
\begin{aligned}
&\mu_{n}^{d}, \sigma_{n}^{d}=g_{\phi_{d}}\left(z_{n}\right) \\
&w_{n} \sim q\left(w \mid z_{n}\right)=N\left\{\mu_{n}^{d}, \operatorname{diag}\left(\sigma_{n}^{d}\right)^{2}\right\}
\end{aligned}
$$
其中,工作 $\varepsilon$ 的根底学习器参数记为 $\theta_{\varepsilon}$,
根底学习器参数由属于每个类别的概率值组成,
记为 $\theta_{\varepsilon}=\left(w_{1}, w_{2}, \cdots, w_{N}\right)$,
根底学习器参数 $\boldsymbol{w}_{n}$ 指的是输出数据属于第 $n$ 个类别的概率值,
$g_{\phi_{d}}$ 是从特征向量到根底学习器参数的映射。
<center>
图 2 LEO 根底学习器工作原理图。
</center>
2.2.3 根底学习器更新过程
在根底学习器中,工作 $\varepsilon$ 的穿插熵损失函数是:
$$
L_{\varepsilon}^{\mathrm{tr}}\left(f_{\theta_{\varepsilon}}\right)=\sum_{(x, y) \in D_{\varepsilon}^{\mathrm{tr}}}\left[-w_{y} \boldsymbol{x}+\log \sum_{j=1}^{N} \mathrm{e}^{w_{j} x}\right]
$$
其中,$(x, y)$ 是工作 $\varepsilon$ 训练集 $D_{\varepsilon}^{\mathrm{tr}}$ 中的样本点,$f_{\theta_{\varepsilon}}$ 是工作 $\varepsilon$ 的根底学习器,
最小化工作 $\varepsilon$ 的损失函数更新工作专属参数 $\theta_{\varepsilon}$。
在解码器模型中,工作专属参数为 $w_{n} \sim q\left(w \mid z_{n}\right)$,
更新工作专属参数 $\theta_{\varepsilon}$ 意味着更新特征向量 $z_{n}$:
$$
z_{n}^{\prime}=z_{n}-\alpha \nabla_{z_{n}} L_{\varepsilon}^{t r}\left(f_{\theta_{\varepsilon}}\right),
$$
其中,$\boldsymbol{z}_{n}^{\prime}$ 是更新后的特征向量,
对应的是更新后的工作专属参数 $\boldsymbol{\theta}_{\varepsilon}^{\prime}$。
根底学习器应用 $\theta_{\varepsilon}^{\prime}$ 来预测工作验证集数据的标注,
将工作 $\varepsilon$ 的验证集 $\mathrm{D}_{\varepsilon}^{\mathrm{val}}$
损失函数 $L_{\varepsilon}^{\mathrm{val}}\left(f_{\theta_{\varepsilon}^{\prime}}\right)$、
更新后的特征向量 $z_{n}^{\prime}$、
更新后的工作专属参数 $\theta_{\varepsilon}^{\prime}$ 输出元学习器,
在元学习器中更新元参数。
2.3 元学习器更新过程
在元学习器中,最小化所有工作 $\varepsilon$ 的验证集的损失函数的求和,
最小化工作上的模型泛化误差:
$$
\min _{\phi_{e}, \phi_{r}, \phi_{d}} \sum_{\varepsilon}\left[L_{\varepsilon}^{\mathrm{val}}\left(f_{\theta_{\varepsilon}^{\prime}}\right)+\beta D_{\mathrm{KL}}\left\{q\left(z_{n} \mid {D}_{n}^{\mathrm{tr}}\right) \| p\left(z_{n}\right)\right\}+\gamma\left\|s\left(\boldsymbol{z}_{n}^{\prime}\right)-\boldsymbol{z}_{n}\right\|_{2}^{2}\right]+R
$$
其中,$L_{\varepsilon}^{\mathrm{val}}\left(f_{\theta_{\varepsilon}^{\prime}}\right)$ 是工作 $\varepsilon$ 验证集的损失函数,
掂量了根底学习器模型的泛化误差,损失函数越小,模型的泛化能力越好。
$p\left(z_{n}\right)=N(0, I)$ 是高斯分布,$D_{\mathrm{KL}}\left\{q\left(z_{n} \mid {D}_{n}^{\mathrm{tr}}\right) \| p\left(z_{n}\right)\right\}$ 是近似后验散布 $q\left(z_{n} \mid D_{n}^{\text {tr}}\right)$ 与先验散布 $p\left(z_{n}\right)$ 之间的 KL 间隔 (KL-Divergence),
最小化 $\mathrm{KL}$ 间隔可使后验散布 $q\left(z_{n} \mid {D}_{n}^{\text {tr}}\right)$ 的预计尽可能精确。
最小化间隔 $\left\|s\left(z_{n}^{\prime}\right)-z_{n}\right\|$ 使得参数初始值 $z_{n}$ 和训练实现后的参数更新值 $z_{n}^{\prime}$ 间隔最小,
使得参数初始值和参数最终值更靠近。
$R$ 是正则项, 用于调控元参数的复杂程度,避免出现过拟合,正则项 $R$ 的计算公式如下:
$$
R=\lambda_{1}\left(\left\|\phi_{e}\right\|_{2}^{2}+\left\|\phi_{r}\right\|_{2}^{2}+\left\|\phi_{d}\right\|_{2}^{2}\right)+\lambda_{2}\left\|C_{d}-\mathbb{I}\right\|_{2}
$$
其中,$\left\|\phi_{r}\right\|_{2}^{2}$ 指的是调控元参数的个数和大小,
${C}_{d}$ 是参数 $\phi_{d}$ 的行和行之间的相关性矩阵,
超参数 $\lambda_{1},\lambda_{2}>0$,
$\left\|C_{d}-\mathbb{I}\right\|_{2}$ 使得 $C_{d}$ 靠近单位矩阵,
使得参数 $\phi_{d}$ 的行和行之间的相关性不能太大,
每个类别的特征向量之间的相关性不能太大,
属于每个类别的概率值之间的相关性也不能太大,分类要尽量精确。
2.4 LEO 算法流程
LEO 算法流程
- randomly initialize $\phi_{e}, \phi_{r}, \phi_{d}$
- let $\phi=\left\{\phi_{e}, \phi_{r}, \phi_{d}, \alpha\right\}$
while not converged do:
for number of tasks in batch do:
- sample task instance $\mathcal{T}_{i} \sim \mathcal{S}^{t r}$
- let $\left(\mathcal{D}^{t r}, \mathcal{D}^{v a l}\right)=\mathcal{T}_{i}$
- encode $\mathcal{D}^{t r}$ to z using $g_{\phi_{e}}$ and $g_{\phi_{r}}$
- decode $\mathbf{z}$ to initial params $\theta_{i}$ using $g_{\phi_{d}}$
- initialize $\mathbf{z}^{\prime}=\mathbf{z}, \theta_{i}^{\prime}=\theta_{i}$
for number of adaptation steps do:
- compute training loss $\mathcal{L}_{\mathcal{T}_{i}}^{t r}\left(f_{\theta_{i}^{\prime}}\right)$
- perform gradient step w.r.t. $\mathbf{z}^{\prime}$:
- $\mathbf{z}^{\prime} \leftarrow \mathbf{z}^{\prime}-\alpha \nabla_{\mathbf{z}^{\prime}} \mathcal{L}_{\mathcal{T}_{i}}^{t r}\left(f_{\theta_{i}^{\prime}}\right)$
decode $\mathbf{z}^{\prime}$ to obtain $\theta_{i}^{\prime}$ using $g_{\phi_{d}}$
- end for
- compute validation loss $\mathcal{L}_{\mathcal{T}_{i}}^{v a l}\left(f_{\theta_{i}^{\prime}}\right)$
- end for
- perform gradient step w.r.t $\phi$:$\phi \leftarrow \phi-\eta \nabla_{\phi} \sum_{\mathcal{T}_{i}} \mathcal{L}_{\mathcal{T}_{i}}^{v a l}\left(f_{\theta_{i}^{\prime}}\right)$
- end while
(1) 初始化元参数:编码器参数 $\phi_{e}$、关系网络参数 $\phi_{r}$、解码器参数 $\phi_{d}$,
在元学习器中更新的元参数包含 $\phi=\left\{\phi_e, \phi_r,\phi_d \right\}$。
(2) 应用片段式训练模式,
随机抽取工作 $\varepsilon$, ${D}_{\varepsilon}^{\mathrm{tr}}$ 是工作 $\varepsilon$ 的训练集,
${D}_{\varepsilon}^{\mathrm{val}}$ 是工作 $\varepsilon$ 的验证集。
(3) 应用编码器 $g_{\phi_{e}}$ 和关系网络 $g_{\phi_{r}}$ 将工作 $\varepsilon$ 的训练集 $D_{\varepsilon}^{\mathrm{tr}}$ 编码成特征向量 $z$,
应用 解码器 $g_{\phi_{d}}$ 从特征向量映射到工作 $\varepsilon$ 的根底学习器参数 ${\theta}_{\varepsilon}$,
根底学习器参数指的是输出数据属于每个类别的概率值向量;
计算工作 $\varepsilon$ 的训练集的损失函数 $L_{\varepsilon}^{\mathrm{tr}}\left(f_{\theta_{\varepsilon}}\right)$,
最小化工作 $\varepsilon$ 的损失函数,更新每个类别的特征向量:
$$
z_{n}^{\prime}=z_{n}-\alpha \nabla_{z_{n}} L_{\varepsilon}^{\mathrm{tr}}\left(f_{\theta_{\varepsilon}}\right)
$$
应用解码器 $g_{\phi_{d}}$ 从更新后的特征向量映射到更新后的工作 $\varepsilon$ 的根底学习器参数 ${\theta}_{\varepsilon}^{\prime}$;
计算工作 $\varepsilon$ 的验证集的损失函数 $L_{\varepsilon}^{\text {val}}\left(f_{\theta_{s}^{\prime}}\right)$;
根底学习器将更新后的参数和验证集损失函数值输出元学习器。
(4) 更新元参数, $\phi \leftarrow \phi-\eta \nabla_{\phi} \sum_{\varepsilon} L_{\varepsilon}^{\text {val}}\left(f_{\theta_{\varepsilon}^{\prime}}\right)$,
最小化所有工作 $\varepsilon$ 的验证集的损失和,
将更新后的元参数输人根底学习器,持续解决新的分类工作。
2.5 LEO 模型构造
LEO 是一种与模型无关的元学习,[1] 中给出的各局部模型构造及参数如表 1 所示。
<center>
表 1 LEO 各局部模型构造及参数。
</center>
Part of the model | Architecture | Hiddenlayer | Shape of the output |
---|---|---|---|
Inference model ($f_{\theta}$) | 3-layer MLP with ReLU | 40 | (12, 5, 1) |
Encoder | 3-layer MLP with ReLU | 16 | (12, 5, 16) |
Relation Network | 3-layer MLP with ReLU | 32 | (12, $2\times 16$) |
Decoder | 3-layer MLP with ReLU | 32 | (12, $2\times 1761$) |
2.6 LEO 分类后果
<center>
表 1 LEO 在 miniImageNet 上的分类后果。
</center>
Model | 5-way 1-shot | 5-way 5-shot |
---|---|---|
Matching networks (Vinyals et al., 2016) | 43.56 $\pm$ 0.84 $\%$ | 55.31 $\pm$ 0.73 $\%$ |
Meta-learner LSTM (Ravi & Larochelle, 2017) | 43.44 $\pm$ 0.77 $\%$ | 60.60 $\pm$ 0.71 $\%$ |
MAML (Finn et al., 2017) | 48.70 $\pm$ 1.84 $\%$ | 63.11 $\pm$ 0.92 $\%$ |
LLAMA (Grant et al., 2018) | 49.40 $\pm$ 1.83 $\%$ | — |
REPTILE (Nichol & Schulman, 2018) | 49.97 $\pm$ 0.32 $\%$ | 65.99 $\pm$ 0.58 $\%$ |
PLATIPUS (Finn et al., 2018) | 50.13 $\pm$ 1.86 $\%$ | — |
Meta-SGD (our features) | 54.24 $\pm$ 0.03 $\%$ | 70.86 $\pm$ 0.04 $\%$ |
SNAIL (Mishra et al., 2018) | 55.71 $\pm$ 0.99 $\%$ | 68.88 $\pm$ 0.92 $\%$ |
(Gidaris & Komodakis, 2018) | 56.20 $\pm$ 0.86 $\%$ | 73.00 $\pm$ 0.64 $\%$ |
(Bauer et al., 2017) | 56.30 $\pm$ 0.40 $\%$ | 73.90 $\pm$ 0.30 $\%$ |
(Munkhdalai et al., 2017) | 57.10 $\pm$ 0.70 $\%$ | 70.04 $\pm$ 0.63 $\%$ |
DEML+Meta-SGD (Zhou et al., 2018) | 58.49 $\pm$ 0.91 $\%$ | 71.28 $\pm$ 0.69 $\%$ |
TADAM (Oreshkin et al., 2018) | 58.50 $\pm$ 0.30 $\%$ | 76.70 $\pm$ 0.30 $\%$ |
(Qiao et al., 2017) | 59.60 $\pm$ 0.41 $\%$ | 73.74 $\pm$ 0.19 $\%$ |
LEO | 61.76 $\pm$ 0.08 $\%$ | 77.59 $\pm$ 0.12 $\%$ |
<center>
表 1 LEO 在 tieredImageNet 上的分类后果。
</center>
Model | 5-way 1-shot | 5-way 5-shot |
---|---|---|
MAML (deeper net, evaluated in Liu et al. (2018)) | 51.67 $\pm$ 1.81 $\%$ | 70.30 $\pm$ 0.08 $\%$ |
Prototypical Nets (Ren et al., 2018) | 53.31 $\pm$ 0.89 $\%$ | 72.69 $\pm$ 0.74 $\%$ |
Relation Net (evaluated in Liu et al. (2018)) | 54.48 $\pm$ 0.93 $\%$ | 71.32 $\pm$ 0.78 $\%$ |
Transductive Prop. Nets (Liu et al., 2018) | 57.41 $\pm$ 0.94 $\%$ | 71.55 $\pm$ 0.74 $\%$ |
Meta-SGD (our features) | 62.95 $\pm$ 0.03 $\%$ | 79.34 $\pm$ 0.06 $\%$ |
LEO | 66.33 $\pm$ 0.05 $\%$ | 81.44 $\pm$ 0.09 $\%$ |
2.7 LEO 的长处
- 新工作的初始参数以训练数据为条件,这使得工作特定的适应终点成为可能。
通过将关系网络联合到编码器中,该初始化能够更好地思考所有输出数据之间的联结关系。 - 通过在低维潜在空间中进行优化,该办法能够更无效地适应模型的行为。
此外,通过容许该过程是随机的,能够表白在多数数据状态中存在的不确定性和模糊性。 - 参考文献
[1] Meta-Learning with Latent Embedding Optimization
3.Reptile
Reptil 是 MAML 的特例、近似和简化,次要解决 MAML 元学习器中呈现的高阶导数问题。
因而,Reptil 同样学习网络参数的初始值,并且实用于任何基于梯度的模型构造。
在 MAML 的元学习器中,应用了求导数的算式来更新参数初始值,
导致在计算中呈现了工作损失函数的二阶导数。
在 Reptile 的元学习器中,参数初始值更新时,
间接应用了工作上的参数估计值和参数初始值之间的差,
来近似损失函数对参数初始值的导数,进行参数初始值的更新,从而不会呈现工作损失函数的二阶导数。
Peptile 有两个版本:Serial Version 和 Batched Version,两者的差别如下:
3.1 Serial Version Reptile
单次更新的 Reptile,每次训练完一个工作的基学习器,就更新一次元学习器中的参数初始值。
(1) 工作上的基学习器记为 $f_{\phi}$,其中 $\phi$ 是基学习器中可训练的参数,
$\theta$ 是元学习器提供给基学习器的参数初始值。
在工作 $T_{i}$ 上,基学习器的损失函数是 $L_{T_{i}}\left(f_{\phi}\right)$,
基学习器中的参数通过 $N$ 次迭代更新失去参数估计值:
$$
\theta_{i}^{N}=\operatorname{SGD}\left(L_{T_{i}}, {\theta}, {N}\right)
$$
(2) 更新元学习器中的参数初始值:
$$
\theta \leftarrow \theta+\varepsilon\left(\theta_{i}^{N}-\theta\right)
$$
Serial Version Reptile 算法流程
- initialize $\theta$, the vector of initial parameters
for iteration=1, 2, … do:
- sample task $T_i$, corresponding to loss $L_{T_i}$ on weight vectors $\theta$
- compute $\theta_{i}^{N}=\operatorname{SGD}\left(L_{T_{i}}, {\theta}, {N}\right)$
- update $\theta \leftarrow \theta+\varepsilon\left(\theta_{i}^{N}-\theta\right)$
- end for
3.2 Batched Version Reptile
批次更新的 Reptile,每次训练完多个工作的基学习器之后,才更新一次元学习器中的参数初始值。
(1) 在多个工作上训练基学习器,每个工作从参数初始值开始,迭代更新 $N$ 次,失去参数估计值。
(2) 更新元学习器中的参数初始值:
$$
\theta \leftarrow \theta+\varepsilon \frac{1}{n} \sum_{i=1}^{n}\left(\theta_{i}^{N}-\theta\right)
$$
其中,$n$ 是指每次训练完 $n$ 个工作上的根底学习器后,才更新一次元学习器中的参数初始值。
Batched Version Reptile 算法流程
- initialize $\theta$
for iteration=1, 2, … do:
- sample tasks $T_1$, $T_2$, … , $T_n$,
for i=1, 2, … , n do:
- compute $\theta_{i}^{N}=\operatorname{SGD}\left(L_{T_{i}}, {\theta}, {N}\right)$
- end for
- update $\theta \leftarrow \theta+\varepsilon \frac{1}{n} \sum_{i=1}^{n}\left(\theta_{i}^{N}-\theta\right)$
- end for
3.3 Reptile 分类后果
<center>
表 1 Reptile 在 Omniglot 上的分类后果。
</center>
Algorithm | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot |
---|---|---|---|---|
MAML + Transduction | 98.7 $\pm$ 0.4 $\%$ | 99.9 $\pm$ 0.1 $\%$ | 95.8 $\pm$ 0.3 $\%$ | 98.9 $\pm$ 0.2 $\%$ |
$1^{st}$-order MAML + Transduction | 98.3 $\pm$ 0.5 $\%$ | 99.2 $\pm$ 0.2 $\%$ | 89.4 $\pm$ 0.5 $\%$ | 97.9 $\pm$ 0.1 $\%$ |
Reptile | 95.32 $\pm$ 0.05 $\%$ | 98.87 $\pm$ 0.02 $\%$ | 88.27 $\pm$ 0.30 $\%$ | 97.07 $\pm$ 0.12 $\%$ |
Reptile + Transduction | 97.97 $\pm$ 0.08 $\%$ | 99.47 $\pm$ 0.04 $\%$ | 89.36 $\pm$ 0.20 $\%$ | 97.47 $\pm$ 0.10 $\%$ |
<center>
表 1 Reptile 在 miniImageNet 上的分类后果。
</center>
Algorithm | 5-way 1-shot | 5-way 5-shot |
---|---|---|
MAML + Transduction | 48.70 $\pm$ 1.84 $\%$ | 63.11 $\pm$ 0.92 $\%$ |
$1^{st}$-order MAML + Transduction | 48.07 $\pm$ 1.75 $\%$ | 63.15 $\pm$ 0.91 $\%$ |
Reptile | 45.79 $\pm$ 0.44 $\%$ | 61.98 $\pm$ 0.69 $\%$ |
Reptile + Transduction | 48.21 $\pm$ 0.69 $\%$ | 66.00 $\pm$ 0.62 $\%$ |
更多优质内容请关注公重号:汀丶人工智能
- 参考文献
[1] Reptile: a Scalable Metalearning Algorithm