乐趣区

关于ai开发:鹅厂专家讲透AI文本生成解码策略与代码实现

腾小云导读

本文以 huggingface-transformers 的文本生成解码代码为例,对文本生成罕用的五种解码策略 greedy search、beam search、sample、sample and rank & beam sample、group beam search 进行逐行解读。每一大节首先会介绍对应解码策略的原理,接着给出供大家疾速上手的代码示例,并逐层介绍调用过程,最初给出所应用到的所有类之间调用的时序图。由简到繁再到简,帮忙大家建设起一个整体的意识,并且可能疾速利用。干货较多,欢送浏览并进行实际尝试。

目录

1 总体介绍

2 greedy search

2.1 原理介绍

2.2 疾速上手

2.3 代码解读

2.4 整体流程

3 beam search

3.1 原理介绍

3.2 疾速上手

3.3 代码解读

3.4 整体流程

4 sample

4.1 原理介绍

4.2 疾速上手

4.3 代码解读

4.4 整体流程

5 sample and rank & beam sample

5.1 原理介绍

5.2 疾速上手

5.3 代码解读

5.4 整体流程

6 group beam search

6.1 原理介绍

6.2 疾速上手

6.3 代码解读

6.4 整体流程

7 总结

8 支流模型计划

01、总体介绍

在 T5/GPT 等自回归模型中,解码策略间接影响到模型输入的成果。在解码第 t 个 token w 时,模型依赖后面的 t-1 个 token,计算概率分布 P(w∣w1:t−1)。依据该概率分布,研究者们设计了各式各样的解码策略,每一种解码策略都对应了一个或多个相干的参数,多种参数糅合在一起,容易让人摸不着头脑。在对应官网提供的 API 中,咱们能够看到也提供了一些用于调整解码策略的参数,如 temperature、top_p 等。

02、greedy search

2.1 原理介绍

最简略的策略就是 greedy decoding,即每步抉择概率最大的 token:

。如上图所示,从单词 The 开始,该策略每步都会抉择下一步概率最大的词,最初会失去输入序列 The nice woman,总概率是 0.5 * 0.4 = 0.2。greedy decoding 速度最快,也有如下几个毛病:

一、它可能会错过全局概率最大的序列。比方上图中,The dog has 的总概率更大,是 0.4 * 0.9 = 0.36。二、因为短少随机性,模型在输入一个反复的 token 之后,有较大可能陷入反复输入序列的循环。三、greedy 解码形式十分靠近模型训练时候的 objective,因而容易复述训练数据,短少了创造性。

2.2 疾速上手

# 环境:python3.9、torch1.13.1、transformers4.26.1
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    StoppingCriteriaList,
    MaxLengthCriteria,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# set pad_token_id to eos_token_id because GPT2 does not have a PAD token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

input_prompt = "It might be possible to"
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

# instantiate logits processors
logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
        RepetitionPenaltyLogitsProcessor(1.2),
    ]
)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

outputs = model.greedy_search(input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
-------------------------------------------------output-------------------------------------------------
['It might be possible to get a better understanding of the nature of this phenomenon, but it is not']

疾速上手的代码参考:Generation,更具体的参数介绍也可从中获取。

链接:https://huggingface.co/docs/transformers/main_classes/text_generation

2.3 代码解读

次要针对疾速上手的第 30-32 行代码调用的 greedy_search 办法进行解读。

代码地址:

transformers/utils.py at v4.26.1 · huggingface/transformers · GitHub

2.3.1 根本设置,对后续须要应用的变量进行初始化

logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
    warnings.warn(
        "`max_length` is deprecated in this function, use"
        "`stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
        UserWarning,
    )
    stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
    eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (output_attentions if output_attentions is not None else self.generation_config.output_attentions)
output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states)
return_dict_in_generate = (
    return_dict_in_generate
    if return_dict_in_generate is not None
    else self.generation_config.return_dict_in_generate
)

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
    encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
    encoder_hidden_states = (model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
    )
1- 1 行:获取 logits_processor,用于后续对 logits 进行预处理;2- 9 行:获取 stopping_criteria,用于后续判断何时进行解码。若设置了解码最大长度,则验证已获取的 stopping_criteria 是否设置正确;10-11 行:获取 pad_token_id、eos_token_id,用于 padding 和辨认句子完结地位;12-13 行:若 eos_token_id 为 int 类型,则将其转换为 list,这么做能够让多个 token 都作为 eos_token,当 eos_token 有多个时,获取的 eos_token_id 则为一个 list,因而其为 int 类型时,须要进行转换;14-19 行:获取 output_scores、output_attentions、output_hidden_states,这三个变量均为 bool 类型,用于决定后续是否须要输入 scores、attentions、hidden_states(生成句子的得分、decoder 每一层的注意力矩阵、decoder 每一层的暗藏状态);20-31 行:获取 return_dict_in_generate,用于判断是否须要将 4. 中几个变量返回给调用方。若须要且对应变量为 True,则初始化 scores、decoder_attentions、cross_attentions、decoder_hidden_states;32-38 行:若模型为 encoder-decoder 架构,则获取 encoder 的 attention 和 hidden_states。

2.3.2 从 bos_token 开始解码

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

this_peer_finished = False  # used by synced_gpus only
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    # prepare model inputs
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # forward pass to get next token
    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]

    # pre-process distribution
    next_tokens_scores = logits_processor(input_ids, next_token_logits)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_tokens_scores,)
        if output_attentions:
            decoder_attentions += ((outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += ((outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # argmax
    next_tokens = torch.argmax(next_tokens_scores, dim=-1)

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

    # update generated ids, model inputs, and length for next step
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder)

    # if eos_token was found in one sentence, set sentence to finished
    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

    # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True
1- 2 行:初始化 unfinished_sequences,维度为[batch_size],用于判断 batch 内句子是否已全副解码实现,值为 1 示意未解码实现,0 示意已解码实现;4- 4 行:初始化 this_peer_finished 为 False,用于阐明以后 gpu 并未实现 batch 内所有句子的解码,仅在 synced_gpus 为 True 时起作用。synced_gpus 为是否须要进行 gpu 间同步的标记;6-14 行:若须要进行 gpu 间的同步,首先初始化 this_peer_finished_flag,若以后 gpu 已实现 batch 内所有句子的解码,则赋值为 0.0,否则赋值为 1.0。之后将所有 gpu 的 this_peer_finished_flag 变量进行相加,若其值为 0.0,阐明所有 gpu 都已实现解码,此时能够完结解码;19-25 行:获取模型输入后果;27-28 行:如果须要进行 gpu 间的同步,且以后 gpu 已对 batch 内所有句子解码实现,则跳过;30-33 行:获取 next_token_logits,维度为[batch_size, vocab_size],即预测的下一个 token 的 logits。之后调用 1. 中初始化的 logits_processor 对 next_token_logits 进行预处理,logits_processor 为 LogitsProcessorList 的实例。

代码:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

class LogitsProcessorList(list):
    """
    This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
    `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
    [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        for processor in self:
            function_args = inspect.signature(processor.__call__).parameters
            if len(function_args) > 2:
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    raise ValueError(f"Make sure that all the required parameters: {list(function_args.keys())} for"
                        f"{processor.__class__} are passed to the logits processor."
                    )
                scores = processor(input_ids, scores, **kwargs)
            else:
                scores = processor(input_ids, scores)
        return scores

此处会调用__call__办法,参数 input_ids 为已生成的序列,scores 为下一步预测 token 的得分。

10-21 行:循环调用 LogitsProcessor 中的 processor。对于每一次循环,首先获取 processor __call__办法的参数,若参数个数大于 2,对参数进行查看,确保所有参数都正确传入了,之后再进行调用。若参数个数小于等于 2,则间接调用。最初返回解决后的得分。

这里介绍疾速上手中应用的两种预处理办法最小长度和反复词惩办对应的 processor。

· 最小长度

代码:transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
    Args:
        min_length (`int`):
            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
    """

    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        if cur_len < self.min_length:
            for i in self.eos_token_id:
                scores[:, i] = -float("inf")
        return scores
上文中调用的__call__办法,即跳转到这里的 23 行;24-28 行:获取以后已生成序列的长度。若以后长度小于预设的最小长度,则遍历所有 eos_token,将其得分设为 -inf。这样就能够保障在以后步解码的后果不会是 eos_token。

· 反复词惩办

代码:transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
    Args:
        repetition_penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        scores.scatter_(1, input_ids, score)
        return scores
上文中调用的__call__办法,即跳转到这里的 16 行;17-17 行:input_ids 是已生成的序列,scores 是以后步预测 token 的得分,维度为[batch_size, vocab_size],gather 相当于是从 scores 里获取已生成 token 的得分 19-20 行:如果已生成 token 的 score < 0,就乘上 penalty;如果 score > 0,就除以 penalty。所以如果 penalty 等于 1.0,相当于 score 没有变动,即没有惩办。当 0.0 < penalty < 1.0,已生成的词的得分会被减少,此时为激励反复词生成。当 penalty > 1.0,已生成词的得分就会被放大,此时为惩办反复词生成;22-22 行:把惩办过的 score 从新赋值回 scores;35-51 行:对 scores、attentions、hidden_states 进行从新赋值;53-60 行:获取 next_tokens,维度为[batch_size],即预测的下一个 token id。之后对 next_tokens 进行从新赋值,若以后句子已解码实现,则将其从新赋值为 pad_token_id,否则不变;62-66 行:更新 input_ids,即已生成的序列,将以后预测的 token 拼接到之前预测的序列之后。之后更新 model_kwargs,如对之前已生成 token 的 key value 缓存等信息进行更新,用于下一次预测;68-71 行:更新 unfinished_sequences,因为 eos_token_id 为一个 list,所以只有 next_tokens 为 eos_token_id 中的任意一个,则都代表已解码实现;72-77 行:判断是否能够完结解码,若 unfinished_sequences 的最大值为 0,阐明 batch 内所有句子已解码实现,能够完结解码了。或者满足了进行条件,也能够完结解码,调用 stopping_criteria 函数的返回值为一个 bool 值,代表是否满足进行条件。另外对是否须要进行 gpu 间的同步进行别离解决,若不须要,则间接完结循环,若须要则设置 this_peer_finished 为 True,表明以后 gpu 已对 batch 内所有句子实现解码。

2.3.3 解码完结,返回后果

if return_dict_in_generate:
    if self.config.is_encoder_decoder:
        return GreedySearchEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )
    else:
        return GreedySearchDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
else:
    return input_ids

若须要返回生成过程中的具体后果,则依据架构为 encoder-decoder 和 decoder-only 别离返回对应 dict,否则间接返回预测序列;

2.4 整体流程

整体流程如上面的时序图所示

03、beam search

3.1 原理介绍

为了解决 greedy decoding 可能错过全局最大概率序列的问题,beam search 策略常常会被采纳,即保护 beam=n,保留以后最佳的 n 个序列,并且对于每个序列,都在计算最好的 n 个 next token,而后再从 n*n 个后果中,保留 n 个概率乘积最大的序列。比方上图中,假如 beam=2,从 The 开始,会保留 [The dog, The nice] 两个序列,接着每个序列选取 2 个最佳的 next token,失去 4 个序列,再从中抉择 2 个最佳序列[The dog has, The nice woman]。然而,beam Search 有以下毛病:

一、在 text generation 中,个别将[EOS] token 视为文本的结尾,也就是 absorbing state。如果某个候选序列达到这个 absorbing state,就不再扩大它。这就会造成 Beam Search 通常会偏向于更短的序列,因为长序列算概率乘积后,数值会绝对短序列更小。因而,个别会在得分函数中引入 length normalization 对长度进行归一化。常见办法是引入 ∈[0,1],= 0 不归一化。=1,规范的长度归一化。二、因为短少随机性,beam search 依然很可能掉入反复序列的循环。因此一些工作引入了 n-grams penalty 来缓解。最常见的办法是通过将曾经看到的 n-gram 的下一个单词的概率设置为 0,来确保没有 n-gram 呈现两次。n 是一个超参数,如果 n 设为 2,则 2-gram 序列,比方 New York 不会在解码中呈现两次。三、最初,相比于人类语句个别不太可预测,beam search 生成的序列短少惊喜,因而在须要创造性的生成场景中不是十分适合。

3.2 疾速上手

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    NoRepeatNGramLogitsProcessor,
    BeamSearchScorer,
)
import torch

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to Chinese: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


# lets run beam search using 3 beams
num_beams = 3
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# add encoder_outputs to model keyword arguments
model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

# instantiate beam scorer
beam_scorer = BeamSearchScorer(
    batch_size=1,
    num_beams=num_beams,
    num_beam_hyps_to_keep=2,
    device=model.device,
)

# instantiate logits processors
logits_processor = LogitsProcessorList(
    [NoRepeatNGramLogitsProcessor(2),
    ]
)

outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True))
print(result)
-------------------------------------------------output-------------------------------------------------
['Wie alt bist du?']

3.3 代码解读

次要针对疾速上手的第 45 行代码调用的 beam_search 办法进行解读

代码地址:

transformers/utils.py at v4.26.1 · huggingface/transformers · GitHub

3.3.1 根本设置,对后续须要应用的变量进行初始化

batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape

if num_beams * batch_size != batch_beam_size:
    raise ValueError(f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
    )

beam_indices = (tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
)

这一步与 greedy search 基本一致,区别在于须要额定初始化一些用于 beam search 的变量。

1- 2 行:获取 batch_size 和候选门路个数;4- 9 行:参数查看,batch_beam_size 必须等于 batch_size * num_beams,这也是实现 beam search 算法的一种具体形式,将每条候选门路都当作 batch 内的一条样本,别离进行解码;11-13 行:beam_indices 为所有候选存储最初一个预测的 token 所在门路的每一步门路下标。

3.3.2 从 bos_token 开始解码

# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False  # used by synced_gpus only
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        cur_len = cur_len + 1
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]
    # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
    # cannot be generated both before and after the `nn.functional.log_softmax` operation.
    next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
    next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

    next_token_scores_processed = logits_processor(input_ids, next_token_scores)
    next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_token_scores_processed,)
        if output_attentions:
            decoder_attentions += ((outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += ((outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # reshape for beam search
    vocab_size = next_token_scores.shape[-1]
    next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

    # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
    next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

    next_indices = torch_int_div(next_tokens, vocab_size)
    next_tokens = next_tokens % vocab_size

    # stateless
    beam_outputs = beam_scorer.process(
        input_ids,
        next_token_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        beam_indices=beam_indices,
    )

    beam_scores = beam_outputs["next_beam_scores"]
    beam_next_tokens = beam_outputs["next_beam_tokens"]
    beam_idx = beam_outputs["next_beam_indices"]

    input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

    model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder)
    if model_kwargs["past_key_values"] is not None:
        model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

    if return_dict_in_generate and output_scores:
        beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

    # increase cur_len
    cur_len = cur_len + 1

    if beam_scorer.is_done or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True

sequence_outputs = beam_scorer.finalize(
    input_ids,
    beam_scores,
    next_tokens,
    next_indices,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    max_length=stopping_criteria.max_length,
    beam_indices=beam_indices,
)
1- 5 行:初始化 beam_scores,维度为[batch_size, num_beams],首先赋值为 0,之后将除第一条候选门路之外的门路分数均赋值为 -1e9,在 7)中将会介绍这么做的起因,最初将维度变换为[batch_size * num_beams],不便后续的矩阵运算;7-32 行:与 greedy search 基本一致;33-35 行:针对 Marian 模型进行非凡解决,该模型不容许在进行 log_softmax 之前和之后生成 pad token;36-41 行:应用 log_softmax 对 next_token_logits 计算概率值。之后对 next_token_scores 进行预处理。最初将预处理后的以后预测 token 的得分与之前预测序列的得分相加,作为该候选门路的以后得分。这里对疾速上手中用到的 n-gram 惩办预处理进行介绍。

代码:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)

        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

        return scores
16-17 行:获取 batch_size 和已生成序列长度;18-18 行:调用 _calc_banned_ngram_tokens 办法,获取以后步须要禁止生成的 token 序列,如果生成了该 token 序列中的任意一个 token,都会和之前时刻生成的 token 组成一个已生成的 ngram,所以只须要获取以后步禁止生成的 token 即可实现禁止生成已生成过的 ngram 的性能。
def _calc_banned_ngram_tokens(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]

    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

    banned_tokens = [_get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens
4- 7 行:如果(以后已生成序列的长度 + 1) < 须要禁用的 ngram 的长度,+ 1 指的是加上以后步预测的 token,阐明必然还没有生成 ngram,那么也不须要禁用任何 ngram;9- 9 行:调用 _get_ngrams 办法,获取已生成的 ngram。
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams
2- 2 行:为每个样本初始化一个 dict,用来保留曾经生成的 ngram;3- 6 行:首先遍历每个样本,gen_tokens 为已生成的序列,generated_ngram 用来以后样本已生成的 ngram。之后通过 gen_tokens[i:] for i in range(ngram_size) 这行代码来生成已生成序列的 ngram,通过以下例子能够很疾速地了解这行代码。
>>> gen_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> for i in range(2):
...     print(gen_tokens[i:])
... 
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> for ngram in zip(*[gen_tokens[i:] for i in range(2)]):
...     print(ngram)
... 
(1, 2)
(2, 3)
(3, 4)
(4, 5)
(5, 6)
(6, 7)
(7, 8)
(8, 9)
(9, 10)
7- 9 行:以后 ngram 除最初一个 token 外的序列作为 key,即前缀,最初一个 token 作为 value,退出到 generated_ngram 中。最初返回所有样本已生成的 ngram;11-14 行:遍历每个样本已生成的 ngram,调用 _get_generated_ngrams 办法获取以后步每个样本须要禁止生成的 token,最初返回。
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])
2- 5 行:start_idx 为已生成序列中最初一个 ngram 的起始地位,cur_len 为已生成序列中最初一个 ngram 除最初一个 token 外的完结地位,因而 prev_input_ids[start_idx: curlen] 即为最初一个 ngram 的前缀,用该前缀去 banned_grams 查找,若存在则取得以后步须要禁止生成的 token,否则为空。最初返回后果;20-23 行:遍历所有被禁止生成的 token,将其得分赋值为 -inf;43-59 行:与 greedy search 雷同;61-63 行:对 next_token_scores 进行维度变换,[batch_size num_beams, vocab_size] -> [batch_size, num_beams vocab_size];65-68 行:获取 score 最高的 2 num_beams 个预测 token 和其得分,留神 next_token_scores 的维度为[batch_size num_beams],在生成第一个 token 时,因为 1)中的设置,除第一条候选门路外的其余门路分数均为 -1e9,因而只会从第一条候选门路中取出 2 num_beams 个后果,在生成后续 token 时,就将是从所有候选门路中去取了,这其实是一种边界解决的小技巧,可能应用雷同的代码去解决第一次解码和后续解码;70-71 行:next_indices 为候选门路的下标,表明该预测 token 属于哪条候选门路,next_tokens 为预测 token 的 id;73-82 行:调用 beam_scorer.process 办法,获取 beam search 的后果。

代码:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

def process(
    self,
    input_ids: torch.LongTensor,
    next_scores: torch.FloatTensor,
    next_tokens: torch.LongTensor,
    next_indices: torch.LongTensor,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
    cur_len = input_ids.shape[-1]
    batch_size = len(self._beam_hyps)
    if not (batch_size == (input_ids.shape[0] // self.group_size)):
        if self.num_beam_groups > 1:
            raise ValueError(f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam"
                f"size of {self.group_size} is expected by the beam scorer."
            )
        else:
            raise ValueError(f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of"
                f"{self.group_size} is expected by the beam scorer."
            )

    device = input_ids.device
    next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
    next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
    next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            if self.num_beams < len(beam_hyp):
                raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
            if eos_token_id is None or pad_token_id is None:
                raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
            # pad the batch
            next_beam_scores[batch_idx, :] = 0
            next_beam_tokens[batch_idx, :] = pad_token_id
            next_beam_indices[batch_idx, :] = 0
            continue

        # next tokens for this sentence
        beam_idx = 0
        for beam_token_rank, (next_token, next_score, next_index) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
        ):
            batch_beam_idx = batch_idx * self.group_size + next_index
            # add to generated hypotheses if end of sentence
            if (eos_token_id is not None) and (next_token.item() in eos_token_id):
                # if beam_token does not belong to top num_beams tokens, it should not be added
                is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                if is_beam_token_worse_than_top_num_beams:
                    continue
                if beam_indices is not None:
                    beam_index = beam_indices[batch_beam_idx]
                    beam_index = beam_index + (batch_beam_idx,)
                else:
                    beam_index = None

                beam_hyp.add(input_ids[batch_beam_idx].clone(),
                    next_score.item(),
                    beam_indices=beam_index,
                )
            else:
                # add next predicted token since it is not eos_token
                next_beam_scores[batch_idx, beam_idx] = next_score
                next_beam_tokens[batch_idx, beam_idx] = next_token
                next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                beam_idx += 1

            # once the beam for next step is full, don't add more tokens to it.
            if beam_idx == self.group_size:
                break

        if beam_idx < self.group_size:
            raise ValueError(f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
                f"{eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
            )

        # Check if we are done so that we can save a pad step if all(done)
        self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(next_scores[batch_idx].max().item(), cur_len
        )

    return UserDict(
        {"next_beam_scores": next_beam_scores.view(-1),
            "next_beam_tokens": next_beam_tokens.view(-1),
            "next_beam_indices": next_beam_indices.view(-1),
        }
    )
11-23 行:参数查看,要求 batch_size 必须等于 input_ids.shape[0] self.group_size,self._beam_hyps 保留 batch 内每条样本所有候选门路的解码后果,长度为 batch_size num_beams,self.group_size 在此处等于 num_beams,后续遇到时用 num_beams 来代替,在另一种解码策略 group beam search 中会再进行具体介绍;25-28 行:next_beam_tokens 为以后步预测的 token,next_beam_scores 为预测 token 对应的门路的得分,next_beam_indices 为预测 token 所在门路的下标,维度均为 [batch_size, 2 num_beams];30-31 行:与 greedy search 雷同;33-33 行:遍历 batch 内每个样本已生成的句子;34-43 行:若以后样本已解码实现,首先进行参数查看,已生成的句子个数不能小于 num_beams,eos_token_id 和 pad_token_id 不能同时为 None。因为已解码实现,所以将以后步预测 token 设为 pad token,对应的门路的得分和所在门路的下标设为 0,这里能够设为 0 的起因是解码实现后,门路得分已存在 self._beam_hyps 中;45-49 行:遍历以后样本在以后步预测的 2 num_beams 个 token,以及其门路的得分和所在门路的下标;50-50 行:batch_beam_idx 为预测 token 在 batch 中的下标;51-67 行:若以后步预测的 token 在 eos_token 中,阐明已解码实现,须要将其退出以后样本的生成后果中。首先,若 beam_token_rank 大于等于 num_beams,因为 score 是通过 log_softmax 运算失去的,是一个正数,因而后续不会再有门路的得分会大于以后步的前 num_beams 个门路的得分了,因而不须要再将该后果退出生成后果之中了。之后,beam_indices 为每个样本最初一个预测的 token 所在门路的每一步门路下标,是一个大小为 batch_size* num_beams 的元组,其中每个元素也是一个元组,若其不为空,则将以后步预测的 token 所在的门路退出对应的元组中;63-67 行:beam_hyp 用来存储以后样本的所有生成后果,若执行到该处,则将以后生成的后果退出该样本的 beam_hyp 中。

代码:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
    """Add a new hypothesis to the list."""
    score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
    if len(self) < self.num_beams or score > self.worst_score:
        self.beams.append((score, hyp, beam_indices))
        if len(self) > self.num_beams:
            sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
            del self.beams[sorted_next_scores[0][1]]
            self.worst_score = sorted_next_scores[1][0]
        else:
            self.worst_score = min(score, self.worst_score)
5- 5 行:计算 score,将所有生成的 token 的 logsoftmax 的值相加,再除以(长度 self.length_penalty),这个 score 也作为这条门路的最终得分,这里除以(长度 self.length_penalty)次要是为了减少或缩小长度更长的序列的得分,当 self.length_penalty > 0 的时候,这一步的计算就会减少长度更长的序列的得分,self.length_penalty < 0 的时候反之;能够通过几个例子来看:
eg1:假如 self.length_penalty = 0
序列 1:今天天气很好(长度 6,sum_logprobs=-0.6)那么 score1 = -0.6 / 6 ** 0 = -0.6 / 1 = -0.6
序列 2:今天天气真的真的很好(长度 10,sum_logprobs=-0.8)那么 score2 = -0.8 / 10 ** 0 = -0.8 / 1 = -0.8
此时 score1 > score2,最终会抉择长度更短的序列 1

eg2:假如 self.length_penalty = 1
序列 1:今天天气很好(长度 6,sum_logprobs=-0.6)那么 score1 = -0.6 / 6 ** 1 = -0.6 / 6 = -0.1
序列 2:今天天气真的真的很好(长度 10,sum_logprobs=-0.8)那么 score2 = -0.8 / 10 ** 1 = -0.8 / 10 = -0.08
此时 score2 > score1,最终会抉择长度更长的序列 2

eg3:假如 self.length_penalty = 2
候选 1:今天天气很好(长度 6,sum_logprobs=-0.6)那么 score1 = -0.6 / 6 ** 2 = -0.6 / 36 = -0.017
候选 2:今天天气真的真的很好(长度 10,sum_logprobs=-0.8)那么 score2 = -0.8 / 10 ** 2 = -0.8 / 100 = -0.008
此时 score2 > score1,最终也会抉择长度更长的序列 2,但能够发现相比二、score2 和 score1 的差值更大了,也就是说当 self.length_penalty > 0 的时候,其值越大,对长度更长的序列的得分减少的越多。
6-13 行:若已生成的序列个数小于 num_beams 或以后门路得分大于之前生成的序列的最差得分,则将其退出 self.beams 中,存储得分,token 序列和所在门路。若退出后已生成的序列个数大于 num_beams,按得分对 self.beams 进行升序排序,去除得分最低的第一个序列,并更新最差得分,否则间接更新最差得分。若以后步预测 token 不在 eos_token 中,则将其得分、token_id 和所在门路退出以后样本的候选之中。beam_idx 为以后样本已生成的候选个数;75-77 行:若以后样本已生成的候选个数等于 num_beams,则完结循环;79-83 行:安全检查,已生成的候选个数若小于 num_beams,则抛出异样,这种异样在以后步预测的 2 * num_beams 个 token 有 num_beams + 1 个以上呈现在 eos_token 中的状况下可能呈现;85-88 行:判断以后样本是否已解码实现。

代码:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
    """
    If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
    one in the heap, then we are done with this sentence.
    """

    if len(self) < self.num_beams:
        return False
    elif self.early_stopping:
        return True
    else:
        cur_score = best_sum_logprobs / cur_len**self.length_penalty
        ret = self.worst_score >= cur_score
        return ret
7- 8 行:若已生成序列个数小于 num_beams,返回 False;否则,若设置了提前进行,则返回 True;否则,判断已生成序列的最差得分是否大于等于以后步得分最高的序列的得分,若大于等于则返回 True,否则返回 False。其中 False 示意未解码实现,True 示意已解码实现;返回以后步预测的 token,预测 token 对应的门路的得分和预测 token 所在门路的下标;84-86 行:从输入中获取以后步预测的 token,预测 token 对应的门路的得分和预测 token 所在门路的下标;88-88 行:更新 input_ids,即已生成的序列,将以后预测的 token 拼接到之前预测的序列之后,其中 input_ids[beam_idx, :] 示意通过所在门路的下标取出该门路已生成的 token 序列;90-94 行:更新 model_kwargs,用于下一次预测。若须要缓存已生成序列的 key-value 和 cross key-value,则依据 beam_idx 对其进行重排序,这是因为每一步预测的 token 所在的门路可能不一样,因而须要选出这些门路对应的 key value 进行缓存;96-97 行:将预测 token 以后所在的门路下标与该门路之前存储的门路下标进行拼接;99-106 行:与 greedy search 雷同;108-117 行:从候选中选出最终须要返回的后果。

代码:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

def finalize(
    self,
    input_ids: torch.LongTensor,
    final_beam_scores: torch.FloatTensor,
    final_beam_tokens: torch.LongTensor,
    final_beam_indices: torch.LongTensor,
    max_length: int,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
    batch_size = len(self._beam_hyps)

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue

        # all open beam hypotheses are added to the beam hypothesis
        # beam hypothesis class automatically keeps the best beams
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
            beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []
    best_indices = []
    best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(self.num_beam_hyps_to_keep):
            best_hyp_tuple = sorted_hyps.pop()
            best_score = best_hyp_tuple[0]
            best_hyp = best_hyp_tuple[1]
            best_index = best_hyp_tuple[2]
            sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

            # append hyp to lists
            best.append(best_hyp)

            # append indices to list
            best_indices.append(best_index)

            best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

    # prepare for adding eos
    sent_lengths_max = sent_lengths.max().item() + 1
    sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
    decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

    if len(best_indices) > 0 and best_indices[0] is not None:
        indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
    else:
        indices = None

    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    if indices is not None:
        indices.fill_(-1)

    # fill with hypotheses and eos_token_id if the latter fits in
    for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
        decoded[i, : sent_lengths[i]] = hypo

        if indices is not None:
            indices[i, : len(best_idx)] = torch.tensor(best_idx)

        if sent_lengths[i] < sent_max_len:
            # inserting only the first eos_token_id
            decoded[i, sent_lengths[i]] = eos_token_id[0]

    return UserDict(
        {
            "sequences": decoded,
            "sequence_scores": best_scores,
            "beam_indices": indices,
        }
    )
12-15 行:与 greedy search 雷同;17-18 行:遍历每个样本生成的后果;19-29 行:若以后样本已实现解码,则跳过。否则将最初一步的生成的所有候选序列都退出到以后样本的生成后果中;31-35 行:self.num_beam_hyps_to_keep 为每个样本须要返回的序列个数,因而 sent_lengths 和 best_scores 别离用于存储最终返回的所有序列的长度和得分,best 用于存储最终返回的所有序列,best_indices 用于存储最终返回的所有序列在每一步抉择的门路下标;37-38 行:遍历每个样本生成的后果;39-39 行:按得分对每个候选序列进行升序排序;40-53 行:遍历 self.num_beam_hyps_to_keep 次,每次从开端弹出一个序列。best_score 为该序列的总得分,best_token 为该序列的所有 token_id,best_index 为该序列每一步抉择的门路下标。更新 sent_lengths、best、best_indices、best_scores;55-58 行:计算序列的最大长度,将以后序列的最大长度 + 1,示意 eos_token 也占一位。max_length 为预设的序列最大长度,最终序列的最长度取以后已生成序列的最大长度和预设的最大长度的最小值。decoded 为最终返回的所有序列,相比 best,其所有序列的长度均为 sent_max_len;60-63 行:indices 为所有序列在每一步抉择的门路下标,同样,相比 best_indices,其长度均为 sent_max_len;65-68 行:若以后已生成序列的最小长度和最大长度不相等,则将 decoded 的值全副填充为 pad_token_id;70-71 行:将 indices 的值全副填充为 -1;73-74 行:遍历所有已生成的序列和其每一步抉择的门路下标;75-75 行:sent_length[i] 为以后序列的长度,将 decoded 的前 sent_length[i] 个 token 用以后序列填充;77-78 行:对 indices 进行填充;80-82 行:将第 sent_length[i] 位填充为 eos_token 84-90 行:返回最终的生成的所有序列、所有序列的得分、所有序列在每一步抉择的门路下标。

3.3.3 解码完结,返回后果

  if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None

        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSearchDecoderOnlyOutput(sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]

这一步的逻辑与 greedy search 基本一致;

3.4 整体流程

04、sample

4.1 原理介绍

4.1.1 Random sampling

随机采样策略依据以后的概率来抽签抉择 next token,即

。如上图所示,任何词都有肯定概率被抉择。该计划生成的序列充斥了创造性,也绝对较少呈现反复序列循环问题。然而它生成的语句却很可能不通顺。

这里个别会引入 temperature,来扭转生成 next token 的概率分布,使其更偏差于 high probability token。具体做法是在 softmax 中引入 t,取值范畴(0, 1]。t 趋近于 0,就变成了 greedy search。通过调整 t 的大小,能够防止 sample from tail distribution。

4.1.2 Top-k sampling

Fan et. al (2018) 提出了 Top-K 采样策略。该策略会在采样之前缩减采样空间,只保留概率最高的 k 个词,而后从新进行归一化失去新的概率分布。比方上图中,取 k=6,则只在 6 个词中进行采样,这 6 个词总概率有可能不高(左图),但也可能十分靠近 1(右图)。这会造成两个问题:

a. 左图中的 people, big, house 等词实际上可能是正当的输入,然而却不在候选里,这就限度了模型的创造性和多样性。

b. 右图中,down, a 的概率很小,然而仍被放在了候选中,这就有可能让模型输入不通顺的垃圾信息。

4.1.3 Top-p (Nucleus) sampling

为了解决上述 top-k 采样的问题,Holtzman et al. (2019) 提出了 top-p 采样策略(nucleus sampling)。给定一个概率阈值 p,从解码词候选集中抉择一个最小集 Vp,使得它们呈现的概率和大于等于 p。而后再对 Vp 做一次 re-scaling,本工夫步仅从 Vp 汇合中解码。

比方上图中,将阈值 p 设为 0.9,左图会从 9 个候选词中筛选,右图会从 3 个候选词中筛选。

从实质上看,Top-p Sampling 和 Top-k Sampling 都是从放大的候选 token 汇合中 sample token,区别在于如何放大候选汇合。在理论应用中,top-k 和 top-p 有时也会同时应用,来防止采样到非常低概率的词,同时保障后果的多样性。

从上表中能够看出,top-p(nucleus)策略的后果是与 human 后果最相近的。并且有 较低的反复率 repetition%

4.2 疾速上手

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TemperatureLogitsWarper,
    StoppingCriteriaList,
    MaxLengthCriteria,
)
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# set pad_token_id to eos_token_id because GPT2 does not have a EOS token
model.config.pad_token_id = model.config.eos_token_id
model.generation_config.pad_token_id = model.config.eos_token_id

input_prompt = "Today is a beautiful day, and"
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

# instantiate logits processors
logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
    ]
)
# instantiate logits processors
logits_warper = LogitsProcessorList(
    [TopKLogitsWarper(50),
        TopPLogitsWarper(0.9)
    ]
)

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

torch.manual_seed(0)
outputs = model.sample(
    input_ids,
    logits_processor=logits_processor,
    logits_warper=logits_warper,
    stopping_criteria=stopping_criteria,
)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
-------------------------------------------------output-------------------------------------------------
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']

4.3 代码解读

次要针对疾速上手的第 41-46 行代码调用的 sample 办法进行解读.

代码地址:

transformers/utils.py at v4.26.1 · huggingface/transformers · GitHub

4.3.1 根本设置,对后续须要应用的变量进行初始化

logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()

这一步与 greedy search 基本相同,惟一区别在于初始化了一个 logits_warper;

4.3.2 从 bos_token 开始解码

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

this_peer_finished = False  # used by synced_gpus only
# auto-regressive generation
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    # prepare model inputs
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # forward pass to get next token
    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]

    # pre-process distribution
    next_token_scores = logits_processor(input_ids, next_token_logits)
    next_token_scores = logits_warper(input_ids, next_token_scores)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_token_scores,)
        if output_attentions:
            decoder_attentions += ((outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += ((outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # sample
    probs = nn.functional.softmax(next_token_scores, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

    # update generated ids, model inputs, and length for next step
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder)

    # if eos_token was found in one sentence, set sentence to finished
    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

    # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True
1-34 行:与 greedy search 雷同;35-35 行:依据采样形式对 next_token_scores 进行预处理,logits_wraper 同样为 LogitsProcessorList 的实例,会循环调用 LogitsProcessor 中的 processor,这里即为 wraper。

这里介绍疾速上手中应用的两个采样办法 top-k 和 top-p 对应的 wraper。

top-k

代码:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

class TopKLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
    Args:
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(self.top_k, scores.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
21-21 行:top_k 参数查看,scores 的维度为 [batch_size, vocab_size],将 top_k 赋值为预设的 top-k 和 vocab_size 的最小值;22-23 行:判断每个 token 是否须要移除,torch.topk(scores, top_k) 的后果为前 top_k 的 scores 和对应的 indices,torch.topk(scores, top_k)[0] 即前 top_k 的 scores,top_k scores 是升序排列,因而 torch.topk(scores, top_k)0 即为前 top_k 个 scores 中的最小值,最初通过 scores 是否小于该最小值来取得须要移除的下标,小于则须要移除,值为 True,否则不须要移除,值为 False;24-25 行:将须要移除的 token 的 score 赋值为 inf。最初返回预处理后的 scores。

top-p

代码:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

class TopPLogitsWarper(LogitsWarper):
    """
    [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
    Args:
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
24-24 行:对 scores 进行升序排序,取得 sorted_logits 和 sorted_indices,维度均为 [batch_size, vocab_size],即排序后的 logits 和对应在词表中的下标;25-25 行:对 sorted_logits 进行 softmax 归一化,获取每个 token 的预测概率值。之后计算 vocab_size 这一维度的累计和,举例来说,对于第一列,值不变,对于第二列,值为第一列和第二列的值相加,对于第三列,值为第一列、第二列和第三列的值相加,以此类推;27-28 行:获取须要移除的 token 的下标,即累计概率小于 1 – top_p 的列;29-31 行:若起码须要生成的 token 个数大于 1,则将须要 sorted_indices_to_remove 的最初 self.min_tokens_to_keep 列从新赋值为 0,示意这些列不移除;33-34 行:因为 sorted_indices_to_remove 是针对 sorted_indices 的,即此时须要移除的下标的并不是 vocab_size 中对应的下标,其值才对应真正须要移除的列,因而通过 scatter 来获取真正须要移除的 token 下标。35-36 行:将对应地位的 scores 赋值为 inf。最初返回预处理后的 scores;37-53 行:与 greedy search 雷同;55-57 行:对 next_token_scores 计算概率值。依据概率值进行不放回采样,采样一个 token 作为预测 token;59-80 行:与 greedy search 雷同。

4.3.3 解码完结,返回后果

if return_dict_in_generate:
    if self.config.is_encoder_decoder:
        return SampleEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )
    else:
        return SampleDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
else:
    return input_ids

这一步的逻辑与 greedy search 基本一致;

4.4 整体流程

整体流程如上面的时序图所示:

05、sample and rank & beam sample

5.1 原理介绍

Adiwardana et al., 2020 提出了 sample-and-rank 解码策略,该办法在对话畛域成果很好。其思维是先通过 random sampling(联合 temperature 调整概率分布)生成出 N 个 sentence,而后再从 n 个 sentence 中抉择概率乘积最大的。

这种形式通过 random sampling 保留了生成后果的多样性和创造性,后又通过 rank 过滤掉了不通顺的序列。上面两个表格比照了 sample 的后果和 beam search 的后果。显著地,sample 后果多样性会更好。

beam sample 办法是 sample and rank 的改良,原理上相似,相比 sample and rank 在最初才对后果排序去获得最佳的 n 个序列,beam sample 在每一步保留以后最佳的 n 个序列 ,既保证了多样性和创造性,又能够 缩小在 rank 阶段须要过滤掉的句子

5.2 疾速上手

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    TopKLogitsWarper,
    TopPLogitsWarper,
    BeamSearchScorer,
)
import torch

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

# lets run beam search using 3 beams
num_beams = 3
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# add encoder_outputs to model keyword arguments
model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

# instantiate beam scorer
beam_scorer = BeamSearchScorer(
    batch_size=1,
    max_length=model.config.max_length,
    num_beams=num_beams,
    device=model.device,
)

# instantiate logits processors
logits_warper = LogitsProcessorList(
    [TopKLogitsWarper(50),
        TopPLogitsWarper(0.9),
    ]
)

outputs = model.beam_sample(input_ids, beam_scorer, logits_warper=logits_warper, **model_kwargs)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
-------------------------------------------------output-------------------------------------------------
['Wie alt bist du?']

5.3 代码解读

次要针对疾速上手的第 46-48 行代码调用的 beam_sample 办法进行解读。

代码地址:transformers/utils.py at ae54e3c3b18bac0832ad62ea9b896dfd52a09850 · huggingface/transformers · GitHub

5.3.1 根本设置,对后续须要应用的变量进行初始化

这一步与 beam search 雷同。

5.3.2 从 bos_token 开始解码

beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False  # used by synced_gpus only
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        cur_len = cur_len + 1
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]

    # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
    # cannot be generated both before and after the `nn.functional.log_softmax` operation.
    next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
    next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

    next_token_scores_processed = logits_processor(input_ids, next_token_scores)
    next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
    next_token_scores = logits_warper(input_ids, next_token_scores)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (logits_warper(input_ids, next_token_scores_processed),)
        if output_attentions:
            decoder_attentions += ((outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += ((outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # reshape for beam search
    vocab_size = next_token_scores.shape[-1]
    next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

    probs = nn.functional.softmax(next_token_scores, dim=-1)

    next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
    next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

    next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
    next_tokens = torch.gather(next_tokens, -1, _indices)

    next_indices = torch_int_div(next_tokens, vocab_size)
    next_tokens = next_tokens % vocab_size

    # stateless
    beam_outputs = beam_scorer.process(
        input_ids,
        next_token_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        beam_indices=beam_indices,
    )
    beam_scores = beam_outputs["next_beam_scores"]
    beam_next_tokens = beam_outputs["next_beam_tokens"]
    beam_idx = beam_outputs["next_beam_indices"]

    input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

    model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder)
    if model_kwargs["past_key_values"] is not None:
        model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

    if return_dict_in_generate and output_scores:
        beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

    # increase cur_len
    cur_len = cur_len + 1

    if beam_scorer.is_done or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True

sequence_outputs = beam_scorer.finalize(
    input_ids,
    beam_scores,
    next_tokens,
    next_indices,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    max_length=stopping_criteria.max_length,
    beam_indices=beam_indices,
)
11-39 行:与 beam search 基本一致;40-40 行:依据采样形式对 next_token_scores 进行预处理,logits_wrapper 为 LogitsProcessorList 的实例,已在 sample 中具体介绍;42-62 行:与 beam search 基本一致;64-70 行:这几行代码做的事件便是 sample and rank 中的 sample,首先对 next_token_scores 计算概率值,依据概率值进行不放回采样,采样 2 * num_beams 个 token 作为候选预测 token。之后依据 token id 去 gather 失去 token 对应的得分。因为采样失去的 token 不肯定是按得分降序排序的,所以须要对 next_token_scores 降序排序,再依据 indices 去 gather 失去对应的 token,保障 token 是按得分降序排序的。72-118 行:与 beam search 基本一致。

5.3.3 解码完结,返回后果

if return_dict_in_generate:
    if not output_scores:
        sequence_outputs["sequence_scores"] = None

    if self.config.is_encoder_decoder:
        return BeamSampleEncoderDecoderOutput(sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            beam_indices=sequence_outputs["beam_indices"],
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )
    else:
        return BeamSampleDecoderOnlyOutput(sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            beam_indices=sequence_outputs["beam_indices"],
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
else:
    return sequence_outputs["sequences"]

这一步的逻辑与 greedy search 基本一致;

5.4 整体流程

整体流程如上面的时序图所示:

06、group beam search

6.1 原理介绍

group beam search 同样是为了解决 beam search 多样性有余的问题,如上图所示,能够发现 beam search 生成的图像形容简直是反复的,这是因为在搜寻树中具备类似的共享门路,导致最终的变动很小。相比之下,group(diverse) beam search 生成的后果则更多样化,也更加相似形容图像的人际差别。

group beam search 次要思路是通过将 beam search 中的候选门路进行分组,在各组内去寻找最优解 。如上图所示,beam search 的候选门路有 6 条,group beam search 将这 6 条候选门路两两作为一组,分为三组。每一步都在各组内的词表空间上来取 top-2 的后果作为以后预测的 token,对于以后组来说,通过 对之前组已生成的 token 进行惩办 ,来保障以后组生成的 token 与之前组不同的概率更大,从而更具 多样性

6.2 疾速上手

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    HammingDiversityLogitsProcessor,
    BeamSearchScorer,
)
import torch

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


# lets run diverse beam search using 6 beams
num_beams = 6
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# add encoder_outputs to model keyword arguments
model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

# instantiate beam scorer
beam_scorer = BeamSearchScorer(
    batch_size=1,
    max_length=model.config.max_length,
    num_beams=num_beams,
    device=model.device,
    num_beam_groups=3,
    num_beam_hyps_to_keep=2,
)

# instantiate logits processors
logits_processor = LogitsProcessorList(
    [HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),    ]
)

outputs = model.group_beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
-------------------------------------------------output-------------------------------------------------
['Wie alt bist du?', 'Wie alt sind Sie?']

6.3 代码解读

次要针对疾速上手的第 47-49 行代码调用的 group beam search 办法进行解读。

代码地址:transformers/utils.py at ae54e3c3b18bac0832ad62ea9b896dfd52a09850 · huggingface/transformers · GitHub

6.3.1 根本设置,对后续须要应用的变量进行初始化

batch_size = len(beam_scorer._beam_hyps)num_beams = beam_scorer.num_beamsnum_beam_groups = beam_scorer.num_beam_groupsnum_sub_beams = num_beams // num_beam_groups

这一步与 beam search 基本一致,区别在于须要额定初始化一些用于 group beam search 的变量。

1- 2 行:获取 batch_size 和候选门路个数;3- 4 行:获取组的个数和组内候选门路个数。

6.3.2 从 bos_token 开始解码

# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in# the same group don't produce same tokens everytime.beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)beam_scores[:, ::num_sub_beams] = 0beam_scores = beam_scores.view((batch_size * num_beams,))this_peer_finished = False  # used by synced_gpus onlywhile True:    if synced_gpus:        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.        # The following logic allows an early break if all peers finished generating their sequence        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)        # send 0.0 if we finished, 1.0 otherwise        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)        # did all peers finish? the reduced sum will be 0.0 then        if this_peer_finished_flag.item() == 0.0:            break    # predicted tokens in cur_len step    current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)    # indices which will form the beams in the next time step    reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)    # do one decoder step on all beams of all sentences in batch    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)    outputs = self(**model_inputs,        return_dict=True,        output_attentions=output_attentions,        output_hidden_states=output_hidden_states,)    if synced_gpus and this_peer_finished:        cur_len = cur_len + 1        continue  # don't waste resources running the code we don't need    if output_scores:        processed_score = torch.zeros_like(outputs.logits[:, -1, :])    for beam_group_idx in range(num_beam_groups):        group_start_idx = beam_group_idx * num_sub_beams        group_end_idx = min(group_start_idx + num_sub_beams, num_beams)        group_size = group_end_idx - group_start_idx        # indices of beams of current group among all sentences in batch        batch_group_indices = []        for batch_idx in range(batch_size):            batch_group_indices.extend([batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]            )        group_input_ids = input_ids[batch_group_indices]        # select outputs of beams of current group only        next_token_logits = outputs.logits[batch_group_indices, -1, :]        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`        # cannot be generated both before and after the `nn.functional.log_softmax` operation.        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)        next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)        vocab_size = next_token_scores.shape[-1]        next_token_scores_processed = logits_processor(group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx)        next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)        next_token_scores = next_token_scores.expand_as(next_token_scores_processed)        if output_scores:            processed_score[batch_group_indices] = next_token_scores_processed        # reshape for beam search        next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)        # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)        next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True)        next_indices = torch_int_div(next_tokens, vocab_size)        next_tokens = next_tokens % vocab_size        # stateless        process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None        beam_outputs = beam_scorer.process(group_input_ids,            next_token_scores,            next_tokens,            next_indices,            pad_token_id=pad_token_id,            eos_token_id=eos_token_id,            beam_indices=process_beam_indices,)        beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]        beam_next_tokens = beam_outputs["next_beam_tokens"]        beam_idx = beam_outputs["next_beam_indices"]        if return_dict_in_generate and output_scores:            beam_indices[beam_group_idx] = tuple(beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))            )        input_ids[batch_group_indices] = group_input_ids[beam_idx]        group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)        current_tokens[batch_group_indices] = group_input_ids[:, -1]        # (beam_idx // group_size) -> batch_idx        # (beam_idx % group_size) -> offset of idx inside the group        reordering_indices[batch_group_indices] = (num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)        )    # Store scores, attentions and hidden_states when required    if return_dict_in_generate:        if output_scores:            scores += (processed_score,)        if output_attentions:            decoder_attentions += ((outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)            )            if self.config.is_encoder_decoder:                cross_attentions += (outputs.cross_attentions,)        if output_hidden_states:            decoder_hidden_states += ((outputs.decoder_hidden_states,)                if self.config.is_encoder_decoder                else (outputs.hidden_states,)            )    input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)    model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder)    if model_kwargs["past_key_values"] is not None:        model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], reordering_indices        )    # increase cur_len    cur_len = cur_len + 1    if beam_scorer.is_done or stopping_criteria(input_ids, scores):        if not synced_gpus:            break        else:            this_peer_finished = Truefinal_beam_indices = sum(beam_indices, ()) if beam_indices is not None else Nonesequence_outputs = beam_scorer.finalize(input_ids,    beam_scores,    next_tokens,    next_indices,    pad_token_id=pad_token_id,    eos_token_id=eos_token_id,    max_length=stopping_criteria.max_length,    beam_indices=final_beam_indices,)
1- 5 行:初始化 beam_scores,维度为 [batch_size, num_beams],首先赋值为 -1e9,之后将第一条候选门路的门路分数均赋值为 0,含意已在 beam search 中介绍;7-17 行:与 beam search 基本一致;19-20 行:初始化 current_tokens,用于存储以后步预测的 token;22-23 行:初始化 reordering_indices,用于后续对缓存的 key value 进行重排序;25-39 行:与 beam search 基本一致;41-41 行:在组级别进行遍历;42-44 行:初始化组的地位和大小信息,beam_group_idx 示意以后是第几组,num_sub_beams 示意每组的候选门路个数,因而 group_start_idx 示意对于一个样本来说,该组在其候选门路中的起始地位,group_end_idx 为该组在其候选门路中的完结地位,左闭右开,group_size 是组的大小,即组内有多少候选门路,留神这里组的大小是针对单个样本的;46-53 行:因为每个样本的所有候选门路会被分成多个组,所以这里是在将所有样本属于该组的候选门路在 batch 内的下标退出到 batch_group_indices 中。通过下标将每个样本属于该组的候选门路从 input_ids 中取出来,退出到 group_input_ids,大小为 group_size * batch_size;55-56 行:取出该组内所有样本在以后步的 logits;58-104 行:与 beam search 基本一致,最初失去的 beam_scores 是预测 token 的得分,beam_next_tokens 是预测 token 的 id,beam_idx 是预测 token 在 group_input_ids 中的下标。须要额定介绍的是 66-67 行对 logits 的预处理,疾速上手中应用的预处理办法为 Hamming 多样性预处理办法,这个办法也只针对 group beam search 应用,作用是使得各个组生成的后果更加具备多样性;与 beam search 基本一致,最初失去的 beam_scores 是预测 token 的得分,beam_next_tokens 是预测 token 的 id,beam_idx 是预测 token 在 group_input_ids 中的下标。须要额定介绍的是 66-67 行对 logits 的预处理,疾速上手中应用的预处理办法为 Hamming 多样性预处理办法,这个办法也只针对 group beam search 应用,作用是使得各个组生成的后果更加具备多样性。

代码:transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

class HammingDiversityLogitsProcessor(LogitsProcessor):    r"""[`LogitsProcessor`] that enforces diverse beam search. Note that this logits processor is only effective for    [`PreTrainedModel.group_beam_search`]. See [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence    Models](https://arxiv.org/pdf/1610.02424.pdf) for more details.    Args:        diversity_penalty (`float`):            This value is subtracted from a beam's score if it generates a token same as any beam from other group at a            particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.        num_beams (`int`):            Number of beams used for group beam search. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more            details.        num_beam_groups (`int`):            Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.            See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.    """def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):        if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):            raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")        self._diversity_penalty = diversity_penalty        if not isinstance(num_beams, int) or num_beams < 2:            raise ValueError("`num_beams` should be an integer strictly larger than 1.")        self._num_beams = num_beams        if not isinstance(num_beam_groups, int) or num_beam_groups < 2:            raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")        if num_beam_groups > num_beams:            raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")        self._num_sub_beams = num_beams // num_beam_groups    def __call__(self,        input_ids: torch.LongTensor,        scores: torch.FloatTensor,        current_tokens: torch.LongTensor,        beam_group_idx: int,) -> torch.FloatTensor:        # hamming diversity: penalise using same token in current group which was used in previous groups at        # the same time step        batch_size = current_tokens.shape[0] // self._num_beams        group_start_idx = beam_group_idx * self._num_sub_beams        group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)        group_size = group_end_idx - group_start_idx        vocab_size = scores.shape[-1]        if group_start_idx == 0:            return scores        for batch_idx in range(batch_size):            # predicted tokens of last time step of previous groups            previous_group_tokens = current_tokens[batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx]            token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)            scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency        return scores
39-44 行:batch_size 为实在的样本个数。后面介绍过,group_start_idx 示意对于一个样本来说,该组在其候选门路中的起始地位,group_end_idx 为该组在其候选门路中的完结地位,左闭右开,group_size 是组的大小,vocab_size 是词表大小;46-47 行:如果以后是第:一组,则不必进行多样性惩办,因为只有在第二组的时候才须要对已生成的 token 进行惩办;49-57 行:遍历每个样本,previous_group_tokens 是以后样本上一组生成的所有 token,token_frequceny 是依据已生成 token 对词表内所有 token 计算失去的频率。之后对以后步所有已生成 token 的得分进行惩办,频率越高惩办的力度越大。最初返回惩办后的得分;106-108 行:依据 beam_idx 从 group_input_ids 中取出预测 token 已生成的序列,对 input_ids 进行更新,将 input_ids 中所有属于该组的样本的候选门路更新为以后步预测 token 的已生成序列。之后将预测 token 与其已生成序列进行拼接。将以后步预测 token 赋值给 current_tokens;110-114 行:更新 reordering_indices,torch_int_div(beam_idx, group_size)即 beam_idx // group_size,示意该预测 token 属于第几个样本,乘上 num_beams 后,即为该样本第一个候选门路在 batch 内的下标。beam % group_size 是预测 token 在该组的偏移地位,与 group_start_idx 相加即为预测 token 在候选门路中的下标。最初与该样本第一个候选门路在 batch 内的下标相加即为该预测 token 在 batch 内的下标。将该下标赋值给 reordering_indices 中 batch_group_indices 的那些地位,示意这些地位的已生成序列在该工夫步后会被映射为预测 token 对应的已生成序列,因而须要缓存这些序列的 key value;116-163 行:与 beam search 统一。

6.3.3 解码完结,返回后果

if return_dict_in_generate:    if not output_scores:        sequence_outputs["sequence_scores"] = None    if self.config.is_encoder_decoder:        return BeamSearchEncoderDecoderOutput(sequences=sequence_outputs["sequences"],            sequences_scores=sequence_outputs["sequence_scores"],            scores=scores,            beam_indices=sequence_outputs["beam_indices"],            encoder_attentions=encoder_attentions,            encoder_hidden_states=encoder_hidden_states,            decoder_attentions=decoder_attentions,            cross_attentions=cross_attentions,            decoder_hidden_states=decoder_hidden_states,        )    else:        return BeamSearchDecoderOnlyOutput(sequences=sequence_outputs["sequences"],            sequences_scores=sequence_outputs["sequence_scores"],            scores=scores,            beam_indices=sequence_outputs["beam_indices"],            attentions=decoder_attentions,            hidden_states=decoder_hidden_states,        )else:    return sequence_outputs["sequences"]

这一步的逻辑与 greedy search 基本一致;

6.4 整体流程

整体流程如上面的时序图所示:

07、总结

通过后面的介绍,置信大家曾经发现了,各种解码策略无非是通过调整 logits(即模型对每个 token 的预测得分)和 batch_size,来取得不同的生成后果。

对 logits 做调整个别又可分为是用于预处理还是采样,对用于预处理的最小长度、反复词惩办这些性能,形象出基类 Processor 类,对用于采样的 top-k、top-p 这些性能,形象出基类 Warper。而所有对 logits 做调整的性能类都能够又退出到 LogitsProcessList,组成一个 pipeline,每次想用哪一个对其进行初始化并退出即可。

对 batch_size 做调整次要在须要生成多个候选或是须要返回多个后果的状况下,对于 beam search 系列的解码策略,通过将 batch_size 扩充候选门路的个数倍,来取得不同的候选序列。对 sample 系列的解码策略,通过将 batch_size 扩充返回后果个数倍,来采样失去不同的后果。

08、支流模型计划

以上计划被支流模型所采纳。上面表格列举了从公开论文中梳理出的解码计划:

模型 解码策略 备注
GPT-2(OpenAI) greedy decoding 浏览了解工作和翻译工作
GPT-3(OpenAI) top-p sampling temperature=1, p=0.9
Meena (Google) sample-and-rank N=20,temperature=0.88,random sampling
LaMDA (Google) sample-and-rank N=16,temperature=1,top-k sampling, k=40
LLaMA (Meta) greedy decoding Question Answering 工作,其它工作不明

以上就是本篇文章的全副分享,看完文章的开发者能够珍藏一下,跟着文章步骤实机进行操作。

参考文献

Holtzman A, Buys J, Du L, et al. The curious case of neural text degeneration[J]. arXiv preprint arXiv:1904.09751, 2019.

Fan A, Lewis M, Dauphin Y. Hierarchical neural story generation[J]. arXiv preprint arXiv:1805.04833, 2018.

Adiwardana D, Luong M T, So D R, et al. Towards a human-like open-domain chatbot[J]. arXiv preprint arXiv:2001.09977, 2020.

Radford A, Wu J, Child R, et al. Language models are unsupervised multitask learners[J]. OpenAI blog, 2019, 1(8): 9.

Brown T, Mann B, Ryder N, et al. Language models are few-shot learners[J]. Advances in neural information processing systems, 2020, 33: 1877-1901.

Thoppilan R, De Freitas D, Hall J, et al. Lamda: Language models for dialog applications[J]. arXiv preprint arXiv:2201.08239, 2022.

Touvron H, Lavril T, Izacard G, et al. LLaMA: Open and Efficient Foundation Language Models[J]. arXiv preprint arXiv:2302.13971, 2023.

Ashwin K V, Michael C, et al. diverse beam search: decoding diverse soulutions from neural sequence models[J]. arXiv preprint arXiv:1610.02424, 2016.

各位开发者能够在腾讯云开发者公众号评论区聊一聊,在本篇文章中学习到了什么?又或者有什么样的疑难?咱们将选取 1 则最有意义的分享,送出腾讯云开发者 - 手段垫 1 个(见下图)。6 月 8 日中午 12 点开奖。

退出移动版