关于深度学习:持续学习常用6种方法总结使ML模型适应新数据的同时保持旧数据的性能

32次阅读

共计 9711 个字符,预计需要花费 25 分钟才能阅读完成。

继续学习是指在不遗记从后面的工作中取得的常识的状况下,按程序学习大量工作的模型。这是一个重要的概念,因为在监督学习的前提下,机器学习模型被训练为针对给定数据集或数据分布的最佳函数。而在事实环境中,数据很少是动态的,可能会发生变化。当面对不可见的数据时,典型的 ML 模型可能会性能降落。这种景象被称为灾难性忘记。

解决这类问题的罕用办法是在蕴含新旧数据的新的更大数据集上对整个模型进行再训练。然而这种做法往往代价昂扬。所以有一个 ML 钻研畛域正在钻研这个问题,基于该畛域的钻研,本文将探讨 6 种办法,使模型能够在放弃旧的性能的同时适应新数据,并防止须要在整个数据集 (旧 + 新) 上进行从新训练。

Prompt

Prompt 想法源于对 GPT 3 的提醒 (短序列的单词) 能够帮忙驱动模型更好地推理和答复。所以在本文中将 Prompt 翻译为提醒。提醒调优是指应用小型可学习的提醒,并将其与理论输出一起作为模型的输出。这容许咱们只在新数据上训练提供提醒的小模型,而无需再训练模型权重。

具体来说,我抉择了应用提醒进行基于文本的密集检索的例子,这个例子改编自 Wang 的文章《Learning to Prompt for continuous Learning》。

该论文的作者应用下图形容了他们的想法:

理论编码的文本输出用作从提醒池中辨认最小匹配对的 key。在将这些标识的提醒输出到模型之前,首先将它们增加到未编码的文本嵌入中。这样做的目标是训练这些提醒来示意新的工作,同时放弃旧的模型不变,这里提醒的很小,大略每个提醒只有 20 个令牌。

 class PromptPool(nn.Module):
     def __init__(self, M = 100, hidden_size = 768, length = 20, N=5):
         super().__init__()
         self.pool = nn.Parameter(torch.rand(M, length, hidden_size), requires_grad=True).float()
         self.keys = nn.Parameter(torch.rand(M, hidden_size), requires_grad=True).float()
         
         self.length = length
         self.hidden = hidden_size
         self.n = N
         
         nn.init.xavier_normal_(self.pool)
         nn.init.xavier_normal_(self.keys)
         
     def init_weights(self, embedding):
         pass
     
     # function to select from pool based on index
     def concat(self, indices, input_embeds):
         subset = self.pool[indices, :] # 2, 2, 20, 768
         
         subset = subset.to("cuda:0").reshape(indices.size(0), 
                                              self.n*self.length, 
                                              self.hidden) # 2, 40, 768
 
         return torch.cat((subset, input_embeds), 1)
     
     # x is cls output
     def query_fn(self, x):
         
         # encode input x to same dim as key using cosine
         x = x / x.norm(dim=1)[:, None]
         k = self.keys / self.keys.norm(dim=1)[:, None]
         
         scores = torch.mm(x, k.transpose(0,1).to("cuda:0"))
         
         # get argmin
         subsets = torch.topk(scores, self.n, 1, False).indices # k smallest
         
         return subsets
 
 pool = PromptPool()

而后咱们应用的经过训练的旧数据模型,训练新的数据,这里只训练提醒局部的权重。

 def train():
     count = 0
     print("*********** Started Training *************")
     
     start = time.time()
     for epoch in range(40):
         model.eval()
         pool.train()
         
         optimizer.zero_grad(set_to_none=True)
         lap = time.time()
         
         for batch in iter(train_dataloader):
             count += 1
             q, p, train_labels = batch
             
             queries_emb = model(input_ids=q['input_ids'].to("cuda:0"),
                                attention_mask=q['attention_mask'].to("cuda:0"))
             passage_emb = model(input_ids=p['input_ids'].to("cuda:0"),
                                attention_mask=p['attention_mask'].to("cuda:0"))      
             
             # pool
             q_idx = pool.query_fn(queries_emb)
             raw_qembedding = model.model.embeddings(input_ids=q['input_ids'].to("cuda:0")) 
             q = pool.concat(indices=q_idx, input_embeds=raw_qembedding)
             
             p_idx = pool.query_fn(passage_emb)
             raw_pembedding = model.model.embeddings(input_ids=p['input_ids'].to("cuda:0")) 
             p = pool.concat(indices=p_idx, input_embeds=raw_pembedding)
             
             qattention_mask = torch.ones(batch_size, q.size(1))
             pattention_mask = torch.ones(batch_size, p.size(1))
             
             queries_emb = model.model(inputs_embeds=q,
                                attention_mask=qattention_mask.to("cuda:0")).last_hidden_state
             passage_emb = model.model(inputs_embeds=p,
                                attention_mask=pattention_mask.to("cuda:0")).last_hidden_state
             
             q_cls = queries_emb[:, pool.n*pool.length+1, :]
             p_cls = passage_emb[:, pool.n*pool.length+1, :]
             
             loss, ql, pl = calc_loss(q_cls, p_cls)                    
             loss.backward()
             
             optimizer.step()
             optimizer.zero_grad(set_to_none=True)
             
             if count % 10 == 0:
                 print("Model Loss:", round(loss.item(),4), \
                       "| QL:", round(ql.item(),4), "| PL:", round(pl.item(),4), \
                       "| Took:", round(time.time() - lap), "seconds\n")
             
                 lap = time.time()
             
             if count % 40 == 0 and count > 0:
                 print("model saved")
                 torch.save(model.state_dict(), model_PATH)
                 torch.save(pool.state_dict(), pool_PATH)
                 
             if count == 4600: return
             
     print("Training Took:", round(time.time() - start), "seconds")
     print("\n*********** Training Complete *************")

训练实现后,后续的推理过程须要将输出与检索到的提醒联合起来。例如这个例子失去了性能—93% 的新数据提醒池,而齐全 (旧 + 新) 训练为—94%。这与原论文中提到的体现相似。然而须要阐明的一点是后果可能会因工作而不同,你应该尝试试验来晓得什么是最好的。

要使此办法成为值得思考的办法,它必须可能在旧数据上保留老模型 > 80% 的性能,同时提醒也应该帮忙模型在新数据上取得良好的性能。

这种办法的毛病是须要应用提醒池,这会减少额定的工夫。这也不是一个永恒的解决方案,然而目前来说是可行的,也或者当前还会有新的办法呈现。

Data Distillation

你可能据说过常识蒸馏一词,这是一种应用来自老师模型的权重来领导和训练较小规模模型的技术。数据蒸馏(Data Distillation)的工作原理也相似,它是应用来自实在数据的权重来训练更小的数据子集。因为数据集的要害信号被提炼并稀释为更小的数据集,咱们对新数据的训练只须要提供一些提炼的数据以放弃旧的性能。

在此示例中,我将数据蒸馏利用于密集检索(文本)工作。目前看没有其他人在这个畛域应用这种办法,所以后果可能不是最好的,但如果你在文本分类上应用这种办法应该会失去不错的后果。

实质上,文本数据蒸馏的想法源于 Li 的一篇题为 Data Distillation for Text Classification 的论文,该论文的灵感来自 Wang 的 Dataset Distillation,他对图像数据进行了蒸馏。Li 用下图形容了文本数据蒸馏的工作:

依据论文,首先将一批蒸馏数据输出到模型以更新其权重。而后应用实在数据评估更新后的模型,并将信号反向流传到蒸馏数据集。该论文在 8 个公共基准数据集上报告了良好的分类后果(> 80% 准确率)。

依照提出的想法,我做了一些小的改变,应用了一批蒸馏数据和多个实在数据。以下是为密集检索训练创立蒸馏数据的代码:

 class DistilledData(nn.Module):
     def __init__(self, num_labels, M, q_len=64, hidden_size=768):
         super().__init__()
         self.num_samples = M
         self.q_len = q_len
         self.num_labels = num_labels
         self.data = nn.Parameter(torch.rand(num_labels, M, q_len, hidden_size), requires_grad=True) # i.e. shape: 1000, 4, 64, 768
     
     # init using model embedding, xavier, or load from state dict
     def init_weights(self, model, path=None):
         if model:
             self.data.requires_grad = False
             print("Init weights using model embedding")
             raw_embedding = model.model.get_input_embeddings()
             soft_embeds = raw_embedding.weight[:, :].clone().detach()
             nums = soft_embeds.size(0)
             for i1 in range(self.num_labels):
                 for i2 in range(self.num_samples):
                     for i3 in range(self.q_len):
                         random_idx = random.randint(0, nums-1)
                         self.data[i1, i2, i3, :] = soft_embeds[random_idx, :]
             print(self.data.shape)
             self.data.requires_grad = True
             
         if not path:
             nn.init.xavier_normal_(self.data)
         else:
             distilled_data.load_state_dict(torch.load(path), strict=False)
     
     # function to sample a passage and positive sample as in the article, i am doing dense retrieval
     def get_sample(self, label):
         q_idx = random.randint(0, self.num_samples-1)
         sampled_dist_q = self.data[label, q_idx, :, :]
         
         p_idx = random.randint(0, self.num_samples-1)
         while q_idx == p_idx: 
             p_idx = random.randint(0, self.num_samples-1)
         sampled_dist_p = self.data[label, p_idx, :, :]
         
         return sampled_dist_q, sampled_dist_p, q_idx, p_idx
       

这是将信号提取到蒸馏数据上的代码

 def distll_train(chunk_size=32):
     count, times = 0, 0
     print("*********** Started Training *************")
     start = time.time()
     lap = time.time()
     
     for epoch in range(40):        
         distilled_data.train()
         
         for batch in iter(train_dataloader):
             count += 1
             # get real query, pos, label, distilled data query, distilled data pos, ... from batch
             q, p, train_labels, dq, dp, q_indexes, p_indexes = batch
             
             for idx in range(0, dq['input_ids'].size(0), chunk_size):
                 model.train()
                 
                 with torch.enable_grad():   
                     # train on distiled data first
                     x1 = dq['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
                     x2 = dp['input_ids'][idx:idx+chunk_size].clone().detach().requires_grad_(True)
                     q_emb = model(inputs_embeds=x1.to("cuda:0"),
                                  attention_mask=dq['attention_mask'][idx:idx+chunk_size].to("cuda:0")).cpu()
                     p_emb = model(inputs_embeds=x2.to("cuda:0"),
                                   attention_mask=dp['attention_mask'][idx:idx+chunk_size].to("cuda:0"))
                     loss = default_loss(q_emb.to("cuda:0"), p_emb)
                     del q_emb, p_emb
                     
                     loss.backward(retain_graph=True, create_graph=False)
                     state_dict = model.state_dict()
                     
                     # update model weights
                     with torch.no_grad():
                         for idx, param in enumerate(model.parameters()):
                             if param.requires_grad and not param.grad is None:
                                 param.data -= (param.grad*3e-5)
 
                 # real data
                 model.eval()
                 q_embs = []
                 p_embs = []
                 for k in range(0, len(q['input_ids']), chunk_size):
                     with torch.no_grad():
                         q_emb = model(input_ids=q['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
                         p_emb = model(input_ids=p['input_ids'][k:k+chunk_size].to("cuda:0"),).cpu()
                         q_embs.append(q_emb)
                         p_embs.append(p_emb)
                 q_embs = torch.cat(q_embs, 0)
                 p_embs = torch.cat(p_embs, 0)
                 r_loss = default_loss(q_embs.to("cuda:0"), p_embs.to("cuda:0"))
                 del q_embs, p_embs
                 
                 # distill backward
                 if count % 2 == 0:
                     d_grad = torch.autograd.grad(inputs=[x1.to("cuda:0")],#, x2.to("cuda:0")],
                                                 outputs=loss,
                                                 grad_outputs=r_loss)
                     indexes = q_indexes
                 else:
                     d_grad = torch.autograd.grad(inputs=[x2.to("cuda:0")],
                             outputs=loss,
                             grad_outputs=r_loss)
                     indexes = p_indexes
                 loss.detach()
                 r_loss.detach()
 
                 grads = torch.zeros(distilled_data.data.shape) # lbl, 10, 100, 768
                 for i, k in enumerate(indexes):
                     grads[train_labels[i], k, :, :] = grads[train_labels[i], k, :, :].to("cuda:0") \
                                                     + d_grad[0][i, :, :]
                 distilled_data.data.grad = grads
                 data_optimizer.step()
                 data_optimizer.zero_grad(set_to_none=True)
 
                 model.load_state_dict(state_dict)
                 model_optimizer.step()
                 model_optimizer.zero_grad(set_to_none=True)
                 
                 if count % 10 == 0:
                     print("Count:", count ,"| Data:", round(loss.item(), 4), "| Model:", \
                           round(r_loss.item(),4), "| Time:", round(time.time() - lap, 4))
                     # print()
                     lap = time.time()
 
                 if count % 100 == 0:  
                     torch.save(model.state_dict(), model_PATH)
                     torch.save(distilled_data.state_dict(), distill_PATH)
 
                 if loss < 0.1 and r_loss < 1: 
                     times += 1
 
                 if times > 100:
                     print("Training Took:", round(time.time() - start), "seconds")
                     print("\n*********** Training Complete *************")
                     return
                 del loss, r_loss, grads, q, p, train_labels, dq, dp, x1, x2, state_dict
                 
     print("Training Took:", round(time.time() - start), "seconds")
     print("\n*********** Training Complete *************")

这里省略了数据加载等代码,训练完蒸馏的数据后,咱们能够通过在其上训练新模型来应用它,例如将其与新数据合并一起训练。

依据我的试验,一个在蒸馏数据上训练的模型 (每个标签只蕴含 4 个样本) 取得了 66% 的最佳性能,而一个齐全在原始数据上训练的模型也是失去了 66% 的最佳性能。而未经训练的一般模型失去 45% 的性能。就像下面提到的这些数字对于密集检索工作可能不太好,分类数据上会好很多。

要使此办法成为在调整模型以适应新数据时值是一个有用的办法,须要可能提取出比原始数据小得多的数据集(即~ 1%)。通过提炼的数据也可能给你一个略低于或等于被动学习办法的体现。

这个办法的长处是能够创立用于永恒应用的蒸馏数据。毛病是提取的数据没有可解释性,并且须要额定的训练工夫。

Curriculum/Active training

Curriculum training 是一种办法,训练时向模型提供训练样本的难度逐步变大。在对新数据进行训练时,此办法须要人工的对工作进行标注,将工作分为简略、中等或艰难,而后对数据进行采样。为了了解模型的简略、中等或艰难意味着什么,我以这张图片为例:

这是在分类工作中的混同矩阵,艰难样本是假阳性(False Positive),是指模型预测为 True 的可能性很高,但实际上不是 True 的样本。中等样本是那些具备中到高的正确性可能性但低于预测阈值的 True Negative。而简略样本则是那些可能性较低的 True Positive/Negative。

Maximally Interfered Retrieval

这是 Rahaf 在题为“Online Continual Learning with Maximally Interfered Retrieval”的论文(1908.04742)中介绍的一种办法。次要思维是,对于正在训练的每个新数据批次,如果针对较新数据更新模型权重,将须要辨认在损失值方面受影响最大的旧样本。保留由旧数据组成的无限大小的内存,并检索最大烦扰的样本以及每个新数据批次以一起训练。

这篇论文在继续学习畛域是一篇成熟的论文,并且有很多援用,因而可能实用于您的案例。

Retrieval Augmentation

检索加强(Retrieval Augmentation)是指通过从汇合中检索我的项目来裁减输出、样本等的技术。这是一个广泛的概念而不是一个特定的技术。咱们到目前为止所探讨的办法,大多数都在肯定水平都是检索相干的操作。Izacard 的题为 Few-shot Learning with Retrieval Augmented Language Models 的论文应用更小的模型取得了杰出的少样本 学习的性能。检索加强也用于许多其余状况,例如单词生成或答复事实问题。

扩大模型

在训练时应用附加层是最常见也最简略的办法,然而不肯定无效,所以在这里不进行具体的探讨,这里的一个例子是 Lewis 的 Efficient Few-Shot Learning without Prompts。应用附加层通常是在新旧数据上取得良好性能的最简略但通过尝试和测试的办法。次要思维是放弃模型权重固定,并通过分类损失在新数据上训练一层或几层。有趣味能够参考他们的 Github(https://github.com/huggingfac…)

总结

在本文中,我介绍了在新数据上训练模型时能够应用的 6 种办法。与平常一样应该进行试验并决定哪种办法最适宜,然而须要留神的是,除了我下面的办法外还有很多办法,例如数据蒸馏是计算机视觉中的一个沉闷畛域,你能够找到很多对于它的论文。最初阐明的一点是:要使这些办法有价值,它们应该在旧数据和新数据上同时取得良好的性能。

https://avoid.overfit.cn/post/2183a89bd513423982e9bccefe1a2896

作者:Gan Yun Tian

正文完
 0