作者:LogM

本文原载于 https://segmentfault.com/u/logm/articles ,不允许转载~

1. 源码来源

FastText 源码:https://github.com/facebookre...

本文对应的源码版本:Commits on Jun 27 2019, 979d8a9ac99c731d653843890c2364ade0f7d9d3

FastText 论文:

[1] P. Bojanowski, E. Grave, A. Joulin, T. Mikolov, Enriching Word Vectors with Subword Information

[2] A. Joulin, E. Grave, P. Bojanowski, T. Mikolov, Bag of Tricks for Efficient Text Classification

2. 概述

FastText 的论文写的比较简单,有些细节不明白,网上也查不到,所幸直接撕源码。

FastText 的"分类器"功能是用的最多的,所以先从"分类器的predict"开始挖。

3. 开撕

先看程序入口的 main 函数,ok,是调用了 predict 函数。

// 文件:src/main.cc// 行数:403int main(int argc, char** argv) {  std::vector<std::string> args(argv, argv + argc);  if (args.size() < 2) {    printUsage();    exit(EXIT_FAILURE);  }  std::string command(args[1]);  if (command == "skipgram" || command == "cbow" || command == "supervised") {    train(args);             } else if (command == "test" || command == "test-label") {    test(args);  } else if (command == "quantize") {    quantize(args);  } else if (command == "print-word-vectors") {    printWordVectors(args);  } else if (command == "print-sentence-vectors") {    printSentenceVectors(args);  } else if (command == "print-ngrams") {    printNgrams(args);  } else if (command == "nn") {    nn(args);  } else if (command == "analogies") {    analogies(args);  } else if (command == "predict" || command == "predict-prob") {    predict(args);       // 这句是我们想要的  } else if (command == "dump") {    dump(args);  } else {    printUsage();    exit(EXIT_FAILURE);  }  return 0;}

再看 predict 函数,预处理的代码不用管,直接看 predict 的那行,调用了 FastText::predictLine。这里注意下,这是个 while 循环,所以
FastText::predictLine 这个函数每次只处理一行。

// 文件:src/main.cc// 行数:205void predict(const std::vector<std::string>& args) {  if (args.size() < 4 || args.size() > 6) {    printPredictUsage();    exit(EXIT_FAILURE);  }  int32_t k = 1;  real threshold = 0.0;  if (args.size() > 4) {    k = std::stoi(args[4]);    if (args.size() == 6) {      threshold = std::stof(args[5]);    }  }  bool printProb = args[1] == "predict-prob";  FastText fasttext;  fasttext.loadModel(std::string(args[2]));  std::ifstream ifs;  std::string infile(args[3]);  bool inputIsStdIn = infile == "-";  if (!inputIsStdIn) {    ifs.open(infile);    if (!inputIsStdIn && !ifs.is_open()) {      std::cerr << "Input file cannot be opened!" << std::endl;      exit(EXIT_FAILURE);    }  }  std::istream& in = inputIsStdIn ? std::cin : ifs;  std::vector<std::pair<real, std::string>> predictions;  while (fasttext.predictLine(in, predictions, k, threshold)) {     // 这句是重点    printPredictions(predictions, printProb, false);  }  if (ifs.is_open()) {    ifs.close();  }  exit(0);}

再看 FastText::predictLine,注意这边有两个重点。

// 文件:src/fasttext.cc// 行数:451bool FastText::predictLine(    std::istream& in,    std::vector<std::pair<real, std::string>>& predictions,    int32_t k,    real threshold) const {  predictions.clear();  if (in.peek() == EOF) {    return false;  }  std::vector<int32_t> words, labels;  dict_->getLine(in, words, labels);                // 这句是第一个重点  Predictions linePredictions;  predict(k, words, linePredictions, threshold);    // 这句是第二个重点  for (const auto& p : linePredictions) {    predictions.push_back(        std::make_pair(std::exp(p.first), dict_->getLabel(p.second)));  }  return true;}

先看第一个重点,getLine 函数其实是 Dictionary::getLine,定义在src/dictionary.cc

这段代码的干货度还是很高的,里面有两个重点,Dictionary::addSubwordsDictionary::addWordNgrams,以后会讲。这边只要知道整个函数把读到的这一行的每个Id(包括词语的id,SubWords的Id,WordNgram的Id),存到了数组 words 中。

// 文件:src/dictionary.cc// 行数:378int32_t Dictionary::getLine(    std::istream& in,    std::vector<int32_t>& words,    std::vector<int32_t>& labels) const {  std::vector<int32_t> word_hashes;  std::string token;  int32_t ntokens = 0;  reset(in);  words.clear();  labels.clear();  while (readWord(in, token)) {     // `token` 是读到的一个词语,如果读到一行的行尾,则返回`EOF`    uint32_t h = hash(token);       // 找到这个词语位于哪个hash桶    int32_t wid = getId(token, h);      // 在hash桶中找到这个词语的Id,如果负数就是没找到对应的Id    entry_type type = wid < 0 ? getType(token) : getType(wid);   // 如果没找到对应Id,则有可能是label,`getType`里会处理    ntokens++;    if (type == entry_type::word) {      addSubwords(words, token, wid);   // 重点1,以后会讲      word_hashes.push_back(h);    } else if (type == entry_type::label && wid >= 0) {      labels.push_back(wid - nwords_);    }    if (token == EOS) {      break;    }  }  addWordNgrams(words, word_hashes, args_->wordNgrams);  // 重点2,以后会讲  return ntokens;}

再来看第二个重点, FastText::predict 函数,重点是 Model::predict 函数。

// 文件:src/fasttext.cc// 行数:437void FastText::predict(    int32_t k,    const std::vector<int32_t>& words,    Predictions& predictions,    real threshold) const {  if (words.empty()) {    return;  }  Model::State state(args_->dim, dict_->nlabels(), 0);  if (args_->model != model_name::sup) {    throw std::invalid_argument("Model needs to be supervised for prediction!");  }  model_->predict(words, k, threshold, predictions, state);       // 这句是重点}

来到 Model::predict,有两个重点.

其中 Loss::predict 是将 hidden 层的输出结果进行 softmax 后得到最终概率最大的k个类别,"分类器的predict" 用的是经典的softmax,所以代码也比较简单。而如果是"分类器的train" 则涉及到 Hierarchical SoftmaxLossNegativeSamplingLoss 等一些加速手段,比较复杂,以后有机会再讲。

// 文件:src/model.cc// 行数:53void Model::predict(    const std::vector<int32_t>& input,    int32_t k,    real threshold,    Predictions& heap,    State& state) const {  if (k == Model::kUnlimitedPredictions) {    k = wo_->size(0); // output size  } else if (k <= 0) {    throw std::invalid_argument("k needs to be 1 or higher!");  }  heap.reserve(k + 1);  computeHidden(input, state);      // 重点1  loss_->predict(k, threshold, heap, state);    // 重点2,以后再讲}

我们再来看另一个重点,Model::computeHidden 函数。

Model::computeHidden 函数理解起来比较简单,注意这里的 input 就是前面的 words,是一系列id组成的数组(包括词语的id,SubWords的Id,WordNgram的Id),把这些求和,然后取平均。

当然有些小伙伴可能有点疑问,Vector::addRow 为什么是求和,这个以后再讲吧。

// 文件:src/model.cc// 行数:43void Model::computeHidden(const std::vector<int32_t>& input, State& state)    const {  Vector& hidden = state.hidden;  hidden.zero();  for (auto it = input.cbegin(); it != input.cend(); ++it) {    hidden.addRow(*wi_, *it);           // 求和  }  hidden.mul(1.0 / input.size());       // 然后取平均}

4. 总结

至此,FastText里面的"分类器的predict"的大致流程讲完了,其他的,如"分类器的train"和"词向量"的源码也是类似的方法来阅读。

这里面有几段代码没有详细叙述:Dictionary::addSubwordsDictionary::addWordNgramsVector::addRow以及训练时softmax的加速,先把坑留着,以后有时间再填。