关于深度学习:gpusharecom基于去噪Transformer的无监督句子编码EMNLP-2021

3次阅读

共计 1806 个字符,预计需要花费 5 分钟才能阅读完成。

文章起源 | 恒源云社区

原文地址 | 论文小记

原文作者 | Mathor


这几天忙里偷闲去社区看了看各位版主一开始发的文章。重点找了我最喜爱的版主 Mathor 的文章,认真一查,居然曾经发了 90 多篇,不愧是社区大佬本佬了!

想着看都看了,那就棘手搬运一下大佬的文章吧!

接下来跟着小编的脚步👣,一起看上来吧~

注释开始

EMNLP2021 Findings 上有一篇名为 TSDAE: Using Transformer-based Sequential Denoising Auto-Encoder for Unsupervised Sentence Embedding Learning 的论文,利用 Transformer 构造无监督训练句子编码,网络架构如下所示:

具体来说,输出的文本增加了一些确定的噪声,例如删除、替换、增加、Mask 一些词等办法。Encoder 须要将含有噪声的句子编码为一个固定大小的向量,而后利用 Decoder 将本来的不带噪声的句子还原。说是这么说,然而其中有十分多细节,首先是训练指标

其中,\(D \)是训练集;\(x = x_1x_2\cdots x_l \)是长度为 lll 的输出句子;\(\tilde{x} \)是 \(x \)增加噪声之后的句子;\(e_t \)是词 \(x_t \)的 word embedding;\(N \)为 Vocabulary size;\(h_t \)是 Decoder 第 \(t \)步输入的 hidden state

不同于原始的 Transformer,作者提出的办法,Decoder 只利用 Encoder 输入的固定大小的向量进行解码,具体来说,Encoder-Decoder 之间的 cross-attention 形式化地示意如下:

其中,\(H^{(k)}\in \mathbb{R}^{t\times d} \)是 Decoder 第 \(k \)层 \(t \)个解码步骤内的 hidden state;\(d \)是句向量的维度(Encoder 输入向量的维度);\([s^T]\in \mathbb{R}^{1\times d} \)是 Encoder 输入的句子(行)向量。从下面的公式咱们能够看出,不管哪一层的 cross-attention,\(K \)和 \(V \)永远都是 \(s^T \),作者这样设计的目标是为了人为给模型增加一个瓶颈,如果 Encoder 编码的句向量 \(s^T \)不够精确,Decoder 就很难解码胜利,换句话说,这样设计是为了使得 Encoder 编码的更加精确。训练完结后如果须要提取句向量只须要用 Encoder 即可

作者通过在 STS 数据集上调参,发现最好的组合办法如下:

  1. 采纳删除单词这种增加噪声的办法,并且比例设置为 60%
  2. 应用 [CLS] 地位的输入作为句向量

RESULTS

从 TSDAE 的后果来看,基本上是拳打 SimCSE,脚踢 BERT-flow

集体总结

如果我是 reviewer,我特地想问的一个问题是:“你们这种办法,与 BART 有什么区别?”

论文源码在 UKPLab/sentence-transformers/,其实 sentence-transformers 曾经把 TSDAE 封装成 pip 包,残缺的训练流程能够参考 Sentence-Transformer 的应用及 fine-tune 教程,在此基础上只须要批改 dataset 和 loss 就能够轻松的训练 TSDAE

# 创立可即时增加噪声的非凡去噪数据集
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)

# DataLoader 批量解决数据
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# 应用去噪主动编码器损失
train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)

# 模型训练
model.fit(train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    weight_decay=0,
    scheduler='constantlr',
    optimizer_params={'lr': 3e-5},
    show_progress_bar=True
)
正文完
 0