共计 5281 个字符,预计需要花费 14 分钟才能阅读完成。
作者:LogM
本文原载于 https://segmentfault.com/u/logm/articles,不允许转载~
1. 源码来源
FastText 源码:https://github.com/facebookre…
本文对应的源码版本:Commits on Jun 27 2019, 979d8a9ac99c731d653843890c2364ade0f7d9d3
FastText 论文:
[1] P. Bojanowski, E. Grave, A. Joulin, T. Mikolov, Enriching Word Vectors with Subword Information
[2] A. Joulin, E. Grave, P. Bojanowski, T. Mikolov, Bag of Tricks for Efficient Text Classification
2. 概述
FastText 的论文写的比较简单,有些细节不明白,网上也查不到,所幸直接撕源码。
FastText 的 ” 分类器 ” 功能是用的最多的,所以先从 ” 分类器的 predict” 开始挖。
3. 开撕
先看程序入口的 main
函数,ok,是调用了 predict
函数。
// 文件:src/main.cc
// 行数:403
int main(int argc, char** argv) {std::vector<std::string> args(argv, argv + argc);
if (args.size() < 2) {printUsage();
exit(EXIT_FAILURE);
}
std::string command(args[1]);
if (command == "skipgram" || command == "cbow" || command == "supervised") {train(args);
} else if (command == "test" || command == "test-label") {test(args);
} else if (command == "quantize") {quantize(args);
} else if (command == "print-word-vectors") {printWordVectors(args);
} else if (command == "print-sentence-vectors") {printSentenceVectors(args);
} else if (command == "print-ngrams") {printNgrams(args);
} else if (command == "nn") {nn(args);
} else if (command == "analogies") {analogies(args);
} else if (command == "predict" || command == "predict-prob") {predict(args); // 这句是我们想要的
} else if (command == "dump") {dump(args);
} else {printUsage();
exit(EXIT_FAILURE);
}
return 0;
}
再看 predict
函数,预处理的代码不用管,直接看 predict 的那行,调用了 FastText::predictLine
。这里注意下,这是个 while
循环,所以FastText::predictLine
这个函数每次只处理一行。
// 文件:src/main.cc
// 行数:205
void predict(const std::vector<std::string>& args) {if (args.size() < 4 || args.size() > 6) {printPredictUsage();
exit(EXIT_FAILURE);
}
int32_t k = 1;
real threshold = 0.0;
if (args.size() > 4) {k = std::stoi(args[4]);
if (args.size() == 6) {threshold = std::stof(args[5]);
}
}
bool printProb = args[1] == "predict-prob";
FastText fasttext;
fasttext.loadModel(std::string(args[2]));
std::ifstream ifs;
std::string infile(args[3]);
bool inputIsStdIn = infile == "-";
if (!inputIsStdIn) {ifs.open(infile);
if (!inputIsStdIn && !ifs.is_open()) {
std::cerr << "Input file cannot be opened!" << std::endl;
exit(EXIT_FAILURE);
}
}
std::istream& in = inputIsStdIn ? std::cin : ifs;
std::vector<std::pair<real, std::string>> predictions;
while (fasttext.predictLine(in, predictions, k, threshold)) { // 这句是重点
printPredictions(predictions, printProb, false);
}
if (ifs.is_open()) {ifs.close();
}
exit(0);
}
再看 FastText::predictLine
,注意这边有两个重点。
// 文件:src/fasttext.cc
// 行数:451
bool FastText::predictLine(
std::istream& in,
std::vector<std::pair<real, std::string>>& predictions,
int32_t k,
real threshold) const {predictions.clear();
if (in.peek() == EOF) {return false;}
std::vector<int32_t> words, labels;
dict_->getLine(in, words, labels); // 这句是第一个重点
Predictions linePredictions;
predict(k, words, linePredictions, threshold); // 这句是第二个重点
for (const auto& p : linePredictions) {
predictions.push_back(std::make_pair(std::exp(p.first), dict_->getLabel(p.second)));
}
return true;
}
先看第一个重点,getLine
函数其实是 Dictionary::getLine
,定义在src/dictionary.cc
。
这段代码的干货度还是很高的,里面有两个重点,Dictionary::addSubwords
和 Dictionary::addWordNgrams
,以后会讲。这边只要知道整个函数把读到的这一行的每个 Id(包括词语的 id,SubWords 的 Id,WordNgram 的 Id),存到了数组 words
中。
// 文件:src/dictionary.cc
// 行数:378
int32_t Dictionary::getLine(
std::istream& in,
std::vector<int32_t>& words,
std::vector<int32_t>& labels) const {
std::vector<int32_t> word_hashes;
std::string token;
int32_t ntokens = 0;
reset(in);
words.clear();
labels.clear();
while (readWord(in, token)) { // `token` 是读到的一个词语,如果读到一行的行尾,则返回 `EOF`
uint32_t h = hash(token); // 找到这个词语位于哪个 hash 桶
int32_t wid = getId(token, h); // 在 hash 桶中找到这个词语的 Id,如果负数就是没找到对应的 Id
entry_type type = wid < 0 ? getType(token) : getType(wid); // 如果没找到对应 Id,则有可能是 label,`getType` 里会处理
ntokens++;
if (type == entry_type::word) {addSubwords(words, token, wid); // 重点 1,以后会讲
word_hashes.push_back(h);
} else if (type == entry_type::label && wid >= 0) {labels.push_back(wid - nwords_);
}
if (token == EOS) {break;}
}
addWordNgrams(words, word_hashes, args_->wordNgrams); // 重点 2,以后会讲
return ntokens;
}
再来看第二个重点,FastText::predict
函数,重点是 Model::predict
函数。
// 文件:src/fasttext.cc
// 行数:437
void FastText::predict(
int32_t k,
const std::vector<int32_t>& words,
Predictions& predictions,
real threshold) const {if (words.empty()) {return;}
Model::State state(args_->dim, dict_->nlabels(), 0);
if (args_->model != model_name::sup) {throw std::invalid_argument("Model needs to be supervised for prediction!");
}
model_->predict(words, k, threshold, predictions, state); // 这句是重点
}
来到 Model::predict
,有两个重点.
其中 Loss::predict
是将 hidden 层的输出结果进行 softmax 后得到最终概率最大的 k 个类别,” 分类器的 predict” 用的是经典的 softmax,所以代码也比较简单。而如果是 ” 分类器的 train” 则涉及到 Hierarchical SoftmaxLoss
和 NegativeSamplingLoss
等一些加速手段,比较复杂,以后有机会再讲。
// 文件:src/model.cc
// 行数:53
void Model::predict(
const std::vector<int32_t>& input,
int32_t k,
real threshold,
Predictions& heap,
State& state) const {if (k == Model::kUnlimitedPredictions) {k = wo_->size(0); // output size
} else if (k <= 0) {throw std::invalid_argument("k needs to be 1 or higher!");
}
heap.reserve(k + 1);
computeHidden(input, state); // 重点 1
loss_->predict(k, threshold, heap, state); // 重点 2,以后再讲
}
我们再来看另一个重点,Model::computeHidden
函数。
Model::computeHidden
函数理解起来比较简单,注意这里的 input
就是前面的 words
,是一系列 id 组成的数组(包括词语的 id,SubWords 的 Id,WordNgram 的 Id),把这些求和,然后取平均。
当然有些小伙伴可能有点疑问,Vector::addRow
为什么是求和,这个以后再讲吧。
// 文件:src/model.cc
// 行数:43
void Model::computeHidden(const std::vector<int32_t>& input, State& state)
const {
Vector& hidden = state.hidden;
hidden.zero();
for (auto it = input.cbegin(); it != input.cend(); ++it) {hidden.addRow(*wi_, *it); // 求和
}
hidden.mul(1.0 / input.size()); // 然后取平均
}
4. 总结
至此,FastText 里面的 ” 分类器的 predict” 的大致流程讲完了,其他的,如 ” 分类器的 train” 和 ” 词向量 ” 的源码也是类似的方法来阅读。
这里面有几段代码没有详细叙述:Dictionary::addSubwords
、Dictionary::addWordNgrams
、Vector::addRow
以及 训练时 softmax 的加速
,先把坑留着,以后有时间再填。