文章起源 | 恒源云社区

原文地址 | BPE 算法详解

原文作者 | Mathor


Byte Pair Encoding

在NLP模型中,输出通常是一个句子,例如"I went to New York last week.",一句话中蕴含很多单词(token)。传统的做法是将这些单词以空格进行分隔,例如['i', 'went', 'to', 'New', 'York', 'last', 'week']。然而这种做法存在很多问题,例如模型无奈通过old, older, oldest之间的关系学到smart, smarter, smartest之间的关系。如果咱们能应用将一个token分成多个subtokens,下面的问题就能很好的解决。本文将详述目前比拟罕用的subtokens算法——BPE(Byte-Pair Encoding)

当初性能比拟好一些的NLP模型,例如GPT、BERT、RoBERTa等,在数据预处理的时候都会有WordPiece的过程,其次要的实现形式就是BPE(Byte-Pair Encoding)。具体来说,例如['loved', 'loving', 'loves']这三个单词。其实自身的语义都是"爱"的意思,然而如果咱们以词为单位,那它们就算不一样的词,在英语中不同后缀的词十分的多,就会使得词表变的很大,训练速度变慢,训练的成果也不是太好。BPE算法通过训练,可能把下面的3个单词拆分成["lov","ed","ing","es"]几局部,这样能够把词的自身的意思和时态离开,无效的缩小了词表的数量。算法流程如下:

  1. 设定最大subwords个数$V$
  2. 将所有单词拆分为单个字符,并在最初增加一个进行符</w>,同时标记出该单词呈现的次数。例如,"low"这个单词呈现了5次,那么它将会被解决为{'l o w </w>': 5}
  3. 统计每一个间断字节对的呈现频率,抉择最高频者合并成新的subword
  4. 反复第3步直到达到第1步设定的subwords词表大小或下一个最高频的字节对呈现频率为1

例如

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

呈现最频繁的字节对是 es ,共呈现了6+3=9次,因而将它们合并

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}

呈现最频繁的字节对是 est,共呈现了6+3=9次,因而将它们合并

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}

呈现最频繁的字节对是 est</w> ,共呈现了6+3=9次,因而将它们合并

{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}

呈现最频繁的字节对是 lo,共呈现了5+2=7次,因而将它们合并

{'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}

呈现最频繁的字节对是 low ,共呈现了5+2=7次,因而将它们合并

{'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}

…持续迭代直到达到预设的subwords词表大小或下一个最高频的字节对呈现频率为1。这样咱们就失去了更加适合的词表,这个词表可能会呈现一些不是单词的组合,然而其自身有意义的一种模式

进行符</w>的意义在于示意subword是词后缀。举例来说:st不加</w>能够呈现在词首,如st ar;加了</w>表明改字词位于词尾,如wide st</w>,二者意义截然不同

BPE实现

import re, collectionsdef get_vocab(filename):    vocab = collections.defaultdict(int)    with open(filename, 'r', encoding='utf-8') as fhand:        for line in fhand:            words = line.strip().split()            for word in words:                vocab[' '.join(list(word)) + ' </w>'] += 1    return vocabdef get_stats(vocab):    pairs = collections.defaultdict(int)    for word, freq in vocab.items():        symbols = word.split()        for i in range(len(symbols)-1):            pairs[symbols[i],symbols[i+1]] += freq    return pairsdef merge_vocab(pair, v_in):    v_out = {}    bigram = re.escape(' '.join(pair))    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')    for word in v_in:        w_out = p.sub(''.join(pair), word)        v_out[w_out] = v_in[word]    return v_outdef get_tokens(vocab):    tokens = collections.defaultdict(int)    for word, freq in vocab.items():        word_tokens = word.split()        for token in word_tokens:            tokens[token] += freq    return tokensvocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}# Get free book from Gutenberg# wget http://www.gutenberg.org/cache/epub/16457/pg16457.txt# vocab = get_vocab('pg16457.txt')print('==========')print('Tokens Before BPE')tokens = get_tokens(vocab)print('Tokens: {}'.format(tokens))print('Number of tokens: {}'.format(len(tokens)))print('==========')num_merges = 5for i in range(num_merges):    pairs = get_stats(vocab)    if not pairs:        break    best = max(pairs, key=pairs.get)    vocab = merge_vocab(best, vocab)    print('Iter: {}'.format(i))    print('Best pair: {}'.format(best))    tokens = get_tokens(vocab)    print('Tokens: {}'.format(tokens))    print('Number of tokens: {}'.format(len(tokens)))    print('==========')

输入如下

==========Tokens Before BPETokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 17, 'r': 2, 'n': 6, 's': 9, 't': 9, 'i': 3, 'd': 3})Number of tokens: 11==========Iter: 0Best pair: ('e', 's')Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 8, 'r': 2, 'n': 6, 'es': 9, 't': 9, 'i': 3, 'd': 3})Number of tokens: 11==========Iter: 1Best pair: ('es', 't')Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 16, 'e': 8, 'r': 2, 'n': 6, 'est': 9, 'i': 3, 'd': 3})Number of tokens: 10==========Iter: 2Best pair: ('est', '</w>')Tokens: defaultdict(<class 'int'>, {'l': 7, 'o': 7, 'w': 16, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'est</w>': 9, 'i': 3, 'd': 3})Number of tokens: 10==========Iter: 3Best pair: ('l', 'o')Tokens: defaultdict(<class 'int'>, {'lo': 7, 'w': 16, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'est</w>': 9, 'i': 3, 'd': 3})Number of tokens: 9==========Iter: 4Best pair: ('lo', 'w')Tokens: defaultdict(<class 'int'>, {'low': 7, '</w>': 7, 'e': 8, 'r': 2, 'n': 6, 'w': 9, 'est</w>': 9, 'i': 3, 'd': 3})Number of tokens: 9==========

编码和解码

编码
在之前的算法中,咱们曾经失去了subword的词表,对该词表依照字符个数由多到少排序。编码时,对于每个单词,遍历排好序的子词词表寻找是否有token是以后单词的子字符串,如果有,则该token是示意单词的tokens之一

咱们从最长的token迭代到最短的token,尝试将每个单词中的子字符串替换为token。 最终,咱们将迭代所有tokens,并将所有子字符串替换为tokens。 如果依然有子字符串没被替换但所有token都已迭代结束,则将残余的子词替换为非凡token,如<unk>

例如

# 给定单词序列["the</w>", "highest</w>", "mountain</w>"]# 排好序的subword表# 长度 6         5           4        4         4       4          2["errrr</w>", "tain</w>", "moun", "est</w>", "high", "the</w>", "a</w>"]# 迭代后果"the</w>" -> ["the</w>"]"highest</w>" -> ["high", "est</w>"]"mountain</w>" -> ["moun", "tain</w>"]

解码
将所有的tokens拼在一起即可,例如

# 编码序列["the</w>", "high", "est</w>", "moun", "tain</w>"]# 解码序列"the</w> highest</w> mountain</w>"

编码和解码实现

import re, collectionsdef get_vocab(filename):    vocab = collections.defaultdict(int)    with open(filename, 'r', encoding='utf-8') as fhand:        for line in fhand:            words = line.strip().split()            for word in words:                vocab[' '.join(list(word)) + ' </w>'] += 1    return vocabdef get_stats(vocab):    pairs = collections.defaultdict(int)    for word, freq in vocab.items():        symbols = word.split()        for i in range(len(symbols)-1):            pairs[symbols[i],symbols[i+1]] += freq    return pairsdef merge_vocab(pair, v_in):    v_out = {}    bigram = re.escape(' '.join(pair))    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')    for word in v_in:        w_out = p.sub(''.join(pair), word)        v_out[w_out] = v_in[word]    return v_outdef get_tokens_from_vocab(vocab):    tokens_frequencies = collections.defaultdict(int)    vocab_tokenization = {}    for word, freq in vocab.items():        word_tokens = word.split()        for token in word_tokens:            tokens_frequencies[token] += freq        vocab_tokenization[''.join(word_tokens)] = word_tokens    return tokens_frequencies, vocab_tokenizationdef measure_token_length(token):    if token[-4:] == '</w>':        return len(token[:-4]) + 1    else:        return len(token)def tokenize_word(string, sorted_tokens, unknown_token='</u>'):        if string == '':        return []    if sorted_tokens == []:        return [unknown_token]    string_tokens = []    for i in range(len(sorted_tokens)):        token = sorted_tokens[i]        token_reg = re.escape(token.replace('.', '[.]'))        matched_positions = [(m.start(0), m.end(0)) for m in re.finditer(token_reg, string)]        if len(matched_positions) == 0:            continue        substring_end_positions = [matched_position[0] for matched_position in matched_positions]        substring_start_position = 0        for substring_end_position in substring_end_positions:            substring = string[substring_start_position:substring_end_position]            string_tokens += tokenize_word(string=substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)            string_tokens += [token]            substring_start_position = substring_end_position + len(token)        remaining_substring = string[substring_start_position:]        string_tokens += tokenize_word(string=remaining_substring, sorted_tokens=sorted_tokens[i+1:], unknown_token=unknown_token)        break    return string_tokens# vocab = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}vocab = get_vocab('pg16457.txt')print('==========')print('Tokens Before BPE')tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)print('All tokens: {}'.format(tokens_frequencies.keys()))print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))print('==========')num_merges = 10000for i in range(num_merges):    pairs = get_stats(vocab)    if not pairs:        break    best = max(pairs, key=pairs.get)    vocab = merge_vocab(best, vocab)    print('Iter: {}'.format(i))    print('Best pair: {}'.format(best))    tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)    print('All tokens: {}'.format(tokens_frequencies.keys()))    print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))    print('==========')# Let's check how tokenization will be for a known wordword_given_known = 'mountains</w>'word_given_unknown = 'Ilikeeatingapples!</w>'sorted_tokens_tuple = sorted(tokens_frequencies.items(), key=lambda item: (measure_token_length(item[0]), item[1]), reverse=True)sorted_tokens = [token for (token, freq) in sorted_tokens_tuple]print(sorted_tokens)word_given = word_given_known print('Tokenizing word: {}...'.format(word_given))if word_given in vocab_tokenization:    print('Tokenization of the known word:')    print(vocab_tokenization[word_given])    print('Tokenization treating the known word as unknown:')    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))else:    print('Tokenizating of the unknown word:')    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))word_given = word_given_unknown print('Tokenizing word: {}...'.format(word_given))if word_given in vocab_tokenization:    print('Tokenization of the known word:')    print(vocab_tokenization[word_given])    print('Tokenization treating the known word as unknown:')    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))else:    print('Tokenizating of the unknown word:')    print(tokenize_word(string=word_given, sorted_tokens=sorted_tokens, unknown_token='</u>'))

输入如下

Tokenizing word: mountains</w>...Tokenization of the known word:['mountains</w>']Tokenization treating the known word as unknown:['mountains</w>']Tokenizing word: Ilikeeatingapples!</w>...Tokenizating of the unknown word:['I', 'like', 'ea', 'ting', 'app', 'l', 'es!</w>']