域适应简介

域适应是迁徙学习中最常见的问题之一,域不同但工作雷同,且源域数据有标签,指标域数据没有标签或者很少数据有标签。
域适应通过将源域和指标域的特色投影到类似的特色空间,这样就能够拿源域的分类器对指标域进行分类了

上面拿二分类做阐明,如下图:

图中红圈是源域,蓝圈是指标域,圆圈和叉是不同特色的数据,源域的分类器将源域的数据分为两类,即虚线所示。
此时如果拿源域的分类器在指标域上分类,从图中能够看到,成果很差。
 
那怎么办呢,有一种办法就是把源域和指标域的散布对齐,如图片左边所示,源域指标域的散布类似(即类似特色的数据分布在相近的地位),这样就能够间接拿源域的分类器对指标域进行分类了。

训练过程域反抗生成网络 GAN 类似
同时训练两个模型:一个用来提取指标域特色 MT,和一个用来判断特色来自源域还是指标域的域分别器 D,MT 的训练过程是最大化 D 产生谬误的过程,即MT提取的特色让 D 分辨不进去是来自源域还是指标域。

指标域特征提取器 MT 和域判断器 D 互为对手:D 学习去判断特色是来自源域还是指标域,MT 学习让本人提取的特色更靠近源域提取出的特色。指标域特征提取器 MT 能够被认为是一个伪造团队,试图产生假货并在不被发现的状况下应用它,而域判断器 D 相似于警察,试图检测假币。在这个游戏中的竞争驱使两个团队改良他们的办法,直到真假难分为止。

对抗性域适应

数据的选取

为了成果好,训练简略,我选取 mnist 数据集中 0、1 的数据作为源域,2、3 的数据作为指标域。源域和指标域的数据各 10000 个。
在训练时,源域可取得数据和标签,而指标域只能取得数据,没有标签,来模仿域适应的背景。指标域的标签仅在测试精度时应用。

网络

1.源域特征提取器 MS、指标域特征提取器 MT。所谓特征提取器,实际上就是将辨认 mnist 的网络去掉最初一层分类层。

        (encoder): Sequential (    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))    (1): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))    (2): ReLU ()    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))    (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))    (5): ReLU ()    )    (fc1): Linear (64 * 4 * 4 -> 512)

把这个网络的输入看作是提取出的特色

2.分类器C。理论就是辨认 mnist 的网络最初一层分类层,一个简略的全连贯网络。

        Classifier (    (fc2): Linear (512 -> 2)    )

3.域识别器 D。依据特征提取器的输入来判断数据来自源域还是指标域,输入 0 代表来自源域,输入 1 代表来自指标域。

        Discriminator (     (layer): Sequential (    (0): Linear (512 -> 512)    (1): Linear (512 -> 512)    (2): Linear (512 -> 2)    ))

 

训练过程

训练MS、C

首先,在源域上训练特征提取器 MS 和分类器 C

训练过程和个别训练过程类似,只不过把整个网络分成了两局部来训练、优化。

def train_MS_C(loader_ms):    # 模型    MS = Encoder()    C = Classifier()    # 优化器    o_ms = optim.SGD(MS.parameters(), lr=0.03)    o_c = optim.SGD(C.parameters(), lr=0.03)    criterion = nn.CrossEntropyLoss()  # 计算损失    for j in range(1):        print(j)        # 训练        for i, (images, labels) in enumerate(loader_ms):            o_ms.zero_grad()            o_c.zero_grad()            outputs_mid = MS(images)            outputs = C(outputs_mid)            loss = criterion(outputs, labels)            loss.backward()            o_ms.step()  # 优化参数            o_c.step()            if i % 100 == 0:                print(i)                print('current loss : %.5f' % loss.data.item())    # 保留模型    np.save(params.MS_save_dir, MS.get_w())    np.save(params.C_save_dir, C.get_w())

训练实现后,在源域的精确度为 0.9985
如果间接拿源域的特征提取器和分类器对指标域进行分类的话,精确度只有 0.5840

固定MS和C,训练MT和D

接着,固定 MS 和 C 不变,即不扭转它们的网络权重,在源域和指标域上反抗式学习指标域特征提取器 MT 和域识别器 D
1.用 MS 初始化 MT,这样开始指标域会取得一个不错的精度 0.5840,接着在这个根底上训练,更容易收敛到好的方向,并且收敛过程也快了。

MT.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())

def train_MT_D(loader_ms, loader_mt):    # 模型    MS = Encoder()    MT = Encoder()    D = Discriminator()    # 加载模型    MS.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())    if params.first_train:        params.first_train = False        # 第一次训练        # MT用MS的权重初始化        MT.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())    else:        MT.update_w(np.load(params.MT_save_dir, encoding='bytes', allow_pickle=True).item())        D.update_w(np.load(params.D_save_dir, encoding='bytes', allow_pickle=True).item())    # 优化器    o_mt = optim.SGD(MT.parameters(), lr=0.00001)    o_d = optim.SGD(D.parameters(), lr=0.00001)    criterion = nn.CrossEntropyLoss()  # 计算损失    # 训练    for j in range(1):        print(j)        # 训练D 域分别器        data_zip = zip(loader_ms, loader_mt)        for i, ((images_s, labels_s), (images_t, labels_t)) in enumerate(data_zip):            ################对域分别器D的训练            # 提取的特色            f_s = MS(images_s)            f_t = MT(images_t)            f_cat = torch.cat((f_s, f_t), 0)            # 域分别器分别后果            out_D = D(f_cat.detach())            predicts_D = torch.max(out_D.data, 1)[1]            if i == 0:                print('域分别器的分别后果')                print(predicts_D)            # 结构损失比照用的标签            len_s = len(labels_s)            len_t = len(labels_t)            temp1 = torch.zeros(len_s)            temp2 = torch.ones(len_t)            lab_D = torch.cat((temp1, temp2), 0).long()            # 梯度置0            o_d.zero_grad()            # 计算loss            loss_D = criterion(out_D, lab_D)            # 反向流传            loss_D.backward()            # 优化网络            o_d.step()            ##############################对指标域特征提取器MT的训练            # 提取的特色            f_t = MT(images_t)            # 域分别器分别后果            d_t = D(f_t)            # 结构计算损失的outputs、labels            out_MT = d_t            predicts_MT = torch.max(out_MT.data, 1)[1]            lab_MT = torch.zeros(len_t).long()            # 梯度置0            o_mt.zero_grad()            # 计算loss            loss_MT = criterion(out_MT, lab_MT)            # 反向流传            loss_MT.backward()            # 优化网络            o_mt.step()            if i % 100 == 0:                print(i)                print('current loss_D : %.5f' % loss_D.data.item())                print('current loss_MT : %.5f' % loss_MT.data.item())    # 保留模型    np.save(params.MT_save_dir, MT.get_w())    np.save(params.D_save_dir, D.get_w())

用MT和C在指标域上分类

最初用训练好的指标域特征提取器 MT 和分类器 C 来在指标域上分类

def test_MT_C(loader_mt):    MT = Encoder()    C = Classifier()    # 加载模型    MT.update_w(np.load(params.MT_save_dir, encoding='bytes', allow_pickle=True).item())    C.update_w(np.load(params.C_save_dir, encoding='bytes', allow_pickle=True).item())    correct = 0    for images, labels in loader_mt:        outputs_mid = MT(images)        outputs = C(outputs_mid)        _, predicts = torch.max(outputs.data, 1)        correct += (predicts == labels).sum()    total = len(loader_mt.dataset)    print('MT+C  Accuracy: %.4f' % (1.0 * correct / total))

试验后果

拿源域的特征提取器和分类器对指标域进行分类的话,精确度只有 0.5840

下图是域分别器 D 的后果,前半部分的输出是源域的特色,后半局部的输出是指标域的特色,当初 D 大部分都能判断正确。

训练几轮后,精确度回升了一点

D 对域的分辨能力降落了,大部分指标域的输出都判断为源域的。

在训练 40 轮后,精确度在 0.9 左近稳定,与开始的 0.5840 相比,精确度晋升了很多

D 无奈分辨源域和指标域了,将所有输出都辨认为源域的。

代码地址

https://momodel.cn/explore/5f1574360a2fac574eb9c3f6?type=app
 

参考

Adversarial Discriminative Domain Adaptation
https://blog.csdn.net/sinat_29381299/article/details/73504196
https://github.com/corenel/pytorch-adda