6 #include "historybigram.h" 22 #include <string_view> 23 #include <unordered_set> 26 #include <fcitx-utils/macros.h> 27 #include <fcitx-utils/stringutils.h> 28 #include "constants.h" 32 #include "zstdfilter.h" 38 using WordWithCode = HistoryBigram::WordWithCode;
39 using WordWithCodeView = HistoryBigram::WordWithCodeView;
41 constexpr uint32_t historyBinaryFormatMagic = 0x000fc315;
42 constexpr uint32_t historyBinaryFormatVersion = 0x4;
43 constexpr
char bigramSeparator =
'\x01';
44 constexpr
char wordCodeSeparator =
'\x02';
46 std::string wordWithCodeToString(WordWithCodeView wordAndCode) {
47 std::string s{std::get<0>(wordAndCode)};
51 auto code = std::get<1>(wordAndCode);
53 s += wordCodeSeparator;
59 std::string bigramWordWithCodeToString(WordWithCodeView prev,
60 WordWithCodeView cur) {
62 const auto &word1 = std::get<0>(prev);
63 const auto &word2 = std::get<0>(cur);
64 const auto &code1 = std::get<1>(prev);
65 const auto &code2 = std::get<1>(cur);
66 const bool hasCode = !code1.empty() || !code2.empty();
67 s.reserve(word1.size() + word2.size() + code1.size() + code2.size() + 1 +
74 s += wordCodeSeparator;
83 using TrieType = DATrie<int32_t>;
86 WeightedTrie() =
default;
88 void clear() { trie_.clear(); }
90 const TrieType &trie()
const {
return trie_; }
92 int32_t weightedSize()
const {
return weightedSize_; }
94 int32_t freq(WordWithCodeView wordAndCode)
const {
98 TrieType::position_type pos = 0;
100 auto v = trie_.traverse(wordAndCode.first, pos);
101 if (TrieType::isValid(v)) {
103 }
else if (TrieType::isNoPath(v)) {
106 const char separator[] = {wordCodeSeparator,
'\0'};
107 v = trie_.traverse(separator, pos);
108 if (!TrieType::isNoPath(v)) {
109 if (!wordAndCode.second.empty()) {
110 v = trie_.traverse(wordAndCode.second, pos);
111 if (TrieType::isValid(v)) {
116 [&result](TrieType::value_type value,
size_t len,
117 TrieType::position_type ) {
130 int32_t freq(WordWithCodeView prev, WordWithCodeView next)
const {
131 const char bigramSeparatorString[] = {bigramSeparator,
'\0'};
132 TrieType::position_type pos = 0;
135 trie_.traverse(pos, prev.first, bigramSeparatorString, next.first);
136 if (TrieType::isValid(v)) {
138 }
else if (TrieType::isNoPath(v)) {
141 const char separator[] = {wordCodeSeparator,
'\0'};
142 v = trie_.traverse(separator, pos);
143 if (!TrieType::isNoPath(v)) {
144 if (!prev.second.empty() && !next.second.empty()) {
145 v = trie_.traverse(pos, prev.second, bigramSeparatorString,
147 if (TrieType::isValid(v)) {
151 if (!prev.second.empty()) {
152 v = trie_.traverse(pos, prev.second, bigramSeparatorString);
153 if (TrieType::isNoPath(v)) {
157 if (!next.second.empty()) {
159 [
this, &result, &next](TrieType::value_type value,
161 TrieType::position_type pos) {
162 if (len < next.second.size()) {
165 std::string codeInTrie;
166 trie().suffix(codeInTrie, len, pos);
167 if (codeInTrie.ends_with(next.second)) {
175 [&result](TrieType::value_type value,
size_t ,
176 TrieType::position_type ) {
187 void incFreqImpl(
const std::string &s, int32_t delta) {
188 trie_.update(s.data(), s.size(),
189 [delta](int32_t v) {
return v + delta; });
190 weightedSize_ += delta;
193 void incFreq(WordWithCodeView wordAndCode, int32_t delta) {
194 incFreqImpl(wordWithCodeToString(wordAndCode), delta);
197 void incFreq(WordWithCodeView prev, WordWithCodeView next, int32_t delta) {
198 incFreqImpl(bigramWordWithCodeToString(prev, next), delta);
201 void decFreqImpl(
const std::string &s, int32_t delta) {
202 auto v = trie_.exactMatchSearch(s.data(), s.size());
203 if (TrieType::isNoValue(v)) {
207 trie_.erase(s.data(), s.size());
211 trie_.set(s.data(), s.size(), v);
212 decWeightedSize(delta);
216 void decFreq(WordWithCodeView wordAndCode, int32_t delta) {
217 decFreqImpl(wordWithCodeToString(wordAndCode), delta);
220 void decFreq(WordWithCodeView prev, WordWithCodeView next, int32_t delta) {
221 decFreqImpl(bigramWordWithCodeToString(prev, next), delta);
224 void fillPredict(std::unordered_set<std::string> &words,
225 std::string_view word,
size_t maxSize)
const {
227 [
this, &words, maxSize](TrieType::value_type,
size_t len,
228 TrieType::position_type pos) {
230 trie().suffix(buf, len, pos);
231 auto separatorPos = buf.find(wordCodeSeparator);
232 if (separatorPos != std::string::npos) {
233 buf.erase(separatorPos);
236 if (buf ==
"<s>" || buf ==
"</s>") {
239 words.emplace(std::move(buf));
241 return maxSize <= 0 || words.size() < maxSize;
246 void decWeightedSize(int32_t v) {
248 weightedSize_ = std::max(weightedSize_, 0);
251 int32_t weightedSize_ = 0;
255 class HistoryBigramPool {
257 HistoryBigramPool(
size_t maxSize) : maxSize_(maxSize) {}
259 void load(std::istream &in) {
262 throw_if_io_fail(unmarshall(in, count));
265 throw_if_io_fail(unmarshall(in, size));
266 std::vector<WordWithCode> sentence;
269 throw_if_io_fail(unmarshallString(in, buffer));
270 std::string_view bufferView{buffer};
271 size_t separatorPos = bufferView.find(wordCodeSeparator);
272 if (separatorPos != std::string_view::npos) {
273 sentence.emplace_back(
274 std::string(bufferView.substr(0, separatorPos)),
275 std::string(bufferView.substr(separatorPos + 1)));
277 sentence.emplace_back(std::move(buffer),
"");
284 void loadText(std::istream &in) {
287 std::vector<std::string> lines;
288 while (std::getline(in, buf)) {
289 lines.emplace_back(buf);
290 if (lines.size() >= maxSize_) {
294 for (
auto &line : lines | std::views::reverse) {
295 std::string_view lineView{line};
296 std::vector<std::string> tokens;
297 bool withCode =
false;
298 while (!lineView.empty()) {
300 auto consumed = fcitx::stringutils::consumeMaybeEscapedValue(
301 lineView, FCITX_WHITESPACE, &token);
302 if (!consumed.empty()) {
303 tokens.push_back(std::move(token));
305 if (tokens.size() == 1 && !lineView.empty() &&
306 lineView.front() ==
'\t') {
312 if (tokens.size() % 2 != 0) {
315 add(std::views::iota(static_cast<size_t>(0),
317 std::views::transform([&tokens](
size_t i) {
318 return WordWithCode{tokens[i * 2], tokens[(i * 2) + 1]};
323 std::views::transform([](
const auto &word) -> WordWithCode {
324 std::vector<std::string> wordWithMaybeCode =
325 fcitx::stringutils::split(
327 fcitx::stringutils::SplitBehavior::KeepEmpty);
328 if (wordWithMaybeCode.size() == 2) {
329 return WordWithCode{wordWithMaybeCode[0],
330 wordWithMaybeCode[1]};
332 return WordWithCode{word,
""};
338 void save(std::ostream &out) {
339 uint32_t count = recent_.size();
340 throw_if_io_fail(marshall(out, count));
344 for (
auto &sentence : recent_ | std::views::reverse) {
345 uint32_t size = sentence.size();
346 throw_if_io_fail(marshall(out, size));
347 for (
const auto &s : sentence) {
348 throw_if_io_fail(marshallString(out, wordWithCodeToString(s)));
353 void dump(std::ostream &out)
const {
354 for (
const auto &sentence : recent_) {
356 bool hasCode = std::ranges::any_of(sentence, [](
const auto &item) {
357 return !std::get<1>(item).empty();
359 for (
const auto &s : sentence) {
365 out << fcitx::stringutils::escapeForValue(std::get<0>(s));
368 << fcitx::stringutils::escapeForValue(std::get<1>(s));
382 template <
typename R>
383 std::list<std::vector<WordWithCode>> add(
const R &sentence) {
384 std::list<std::vector<WordWithCode>> popedSentence;
385 if (sentence.empty()) {
386 return popedSentence;
389 if (std::ranges::any_of(sentence, [](
const auto &item) {
390 const auto &[word, code] = item;
391 return word.find(
'\0') != std::string::npos;
393 return popedSentence;
395 while (recent_.size() >= maxSize_) {
396 remove(recent_.back());
397 popedSentence.splice(popedSentence.end(), recent_,
398 std::prev(recent_.end()));
401 std::vector<WordWithCode> newSentence;
403 for (
auto iter = sentence.begin(), end = sentence.end(); iter != end;
405 unigram_.incFreq(*iter, delta);
406 auto next = std::ranges::next(iter);
408 incBigram(*iter, *next, delta);
410 newSentence.push_back(*iter);
412 recent_.push_front(std::move(newSentence));
413 unigram_.incFreq({
"<s>",
""}, delta);
414 unigram_.incFreq({
"</s>",
""}, delta);
415 incBigram({
"<s>",
""}, sentence.front(), delta);
416 incBigram(sentence.back(), {
"</s>",
""}, delta);
418 return popedSentence;
421 int32_t unigramFreq(WordWithCodeView s)
const {
return unigram_.freq(s); }
423 int32_t bigramFreq(WordWithCodeView s, WordWithCodeView s2)
const {
424 return bigram_.freq(s, s2);
427 bool isUnknown(WordWithCodeView word)
const {
428 return unigramFreq(word) == 0;
431 size_t maxSize()
const {
return maxSize_; }
433 size_t realSize()
const {
return recent_.size(); }
435 void forget(std::string_view word, std::string_view code) {
436 auto iter = recent_.begin();
437 while (iter != recent_.end()) {
439 iter->begin(), iter->end(), [word, code](
const auto &item) {
440 const auto &[w, c] = item;
441 return w == word && (code.empty() || c == code);
444 iter = recent_.erase(iter);
451 void fillPredict(std::unordered_set<std::string> &words,
452 std::string_view word,
size_t maxSize = 0)
const {
453 bigram_.fillPredict(words, word, maxSize);
456 bool maybeAppendToLatestSentence(
const std::vector<WordWithCode> &context,
457 std::vector<WordWithCode> &newSentence) {
458 if (recent_.empty() || newSentence.empty()) {
461 auto &latestSentence = recent_.front();
462 if (latestSentence.size() < context.size() ||
465 std::views::drop(latestSentence,
466 latestSentence.size() - context.size()))) {
471 decBigram(latestSentence.back(), {
"</s>",
""}, delta);
472 for (
auto &item : newSentence) {
473 unigram_.incFreq(item, delta);
474 incBigram(latestSentence.back(), item, delta);
475 latestSentence.push_back(std::move(item));
477 incBigram(latestSentence.back(), {
"</s>",
""}, delta);
483 template <
typename R>
484 void remove(
const R &sentence) {
486 for (
auto iter = sentence.begin(), end = sentence.end(); iter != end;
488 unigram_.decFreq(*iter, delta);
489 auto next = std::next(iter);
491 decBigram(*iter, *next, delta);
494 decBigram({
"<s>",
""}, sentence.front(), delta);
495 decBigram(sentence.back(), {
"</s>",
""}, delta);
498 void decBigram(WordWithCodeView s1, WordWithCodeView s2, int32_t delta) {
499 bigram_.decFreq(s1, s2, delta);
502 void incBigram(WordWithCodeView s1, WordWithCodeView s2,
int delta) {
503 bigram_.incFreq(s1, s2, delta);
506 const size_t maxSize_;
510 std::list<std::vector<WordWithCode>> recent_;
513 WeightedTrie unigram_;
514 WeightedTrie bigram_;
529 void populateSentence(std::list<std::vector<WordWithCode>> popedSentence) {
530 for (
size_t i = 1; !popedSentence.empty() && i < pools_.size(); i++) {
531 std::list<std::vector<WordWithCode>> nextSentences;
532 while (!popedSentence.empty()) {
533 auto newPopedSentence = pools_[i].add(popedSentence.front());
534 popedSentence.pop_front();
535 nextSentences.splice(nextSentences.end(), newPopedSentence);
537 popedSentence = std::move(nextSentences);
541 float unigramFreq(WordWithCodeView word)
const {
542 assert(pools_.size() == poolWeight_.size());
544 for (
size_t i = 0; i < pools_.size(); i++) {
545 freq += pools_[i].unigramFreq(word) * poolWeight_[i];
550 float bigramFreq(WordWithCodeView prev, WordWithCodeView cur)
const {
551 assert(pools_.size() == poolWeight_.size());
553 for (
size_t i = 0; i < pools_.size(); i++) {
554 freq += pools_[i].bigramFreq(prev, cur) * poolWeight_[i];
559 float unigramSize()
const {
561 for (
size_t i = 0; i < pools_.size(); i++) {
562 size += pools_[i].maxSize() * poolWeight_[i];
569 std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY);
570 bool useOnlyUnigram_ =
false;
571 std::vector<HistoryBigramPool> pools_;
572 std::vector<float> poolWeight_;
575 HistoryBigram::HistoryBigram()
576 : d_ptr(std::make_unique<HistoryBigramPrivate>()) {
578 const float p = 1.0 / (1 + HISTORY_BIGRAM_ALPHA_VALUE);
579 constexpr std::array<int, 3> poolSize = {128, 8192, 65536};
580 d->pools_.reserve(poolSize.size());
581 d->poolWeight_.reserve(poolSize.size());
582 for (
auto size : poolSize) {
583 d->pools_.emplace_back(size);
584 float portion = 1.0F;
585 if (d->pools_.size() != poolSize.size()) {
588 portion *= std::pow(p, d->pools_.size() - 1);
589 d->poolWeight_.push_back(portion / d->pools_.back().maxSize());
592 std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY));
599 d->unknown_ = unknown;
602 float HistoryBigram::unknownPenalty()
const {
607 void HistoryBigram::setUseOnlyUnigram(
bool useOnlyUnigram) {
609 d->useOnlyUnigram_ = useOnlyUnigram;
612 bool HistoryBigram::useOnlyUnigram()
const {
614 return d->useOnlyUnigram_;
618 addWithCode(sentence,
nullptr);
621 void HistoryBigram::addWithCode(
623 const ValidationCodeExtractor &validationCodeExtractor) {
625 d->populateSentence(d->pools_[0].add(
626 sentence.sentence() |
627 std::views::transform(
628 [&validationCodeExtractor](
const auto &item) -> WordWithCode {
629 return {item->word(), validationCodeExtractor
630 ? validationCodeExtractor(item)
635 void HistoryBigram::add(
const std::vector<std::string> &sentence) {
637 d->populateSentence(d->pools_[0].add(
638 sentence | std::views::transform([](
const auto &word) -> WordWithCode {
639 return WordWithCode{word,
""};
643 void HistoryBigram::addWithCode(
644 const std::vector<WordWithCode> &sentenceWithValidationCode) {
646 d->populateSentence(d->pools_[0].add(sentenceWithValidationCode));
649 bool HistoryBigram::isUnknown(std::string_view v)
const {
651 return std::ranges::all_of(d->pools_, [v](
const HistoryBigramPool &pool) {
652 return pool.isUnknown({v,
""});
656 float HistoryBigram::score(std::string_view prev, std::string_view cur)
const {
657 return scoreWithCode({prev,
""}, {cur,
""});
660 float HistoryBigram::scoreWithCode(WordWithCodeView prev,
661 WordWithCodeView cur)
const {
663 if (prev.first.empty()) {
666 if (cur.first.empty()) {
670 auto uf0 = d->unigramFreq(prev);
671 auto bf = d->bigramFreq(prev, cur);
672 auto uf1 = d->unigramFreq(cur);
674 float bigramWeight = d->useOnlyUnigram_ ? 0.0F : 0.8F;
677 pr += bigramWeight * bf / (uf0 + (d->poolWeight_[0] / 2.0F));
678 pr += (1.0F - bigramWeight) * uf1 /
679 (d->unigramSize() + (d->poolWeight_[0] / 2.0F));
681 pr = std::min<float>(pr, 1.0F);
686 return std::log10(pr);
689 void HistoryBigram::load(std::istream &in) {
692 uint32_t version = 0;
693 throw_if_io_fail(unmarshall(in, magic));
694 if (magic != historyBinaryFormatMagic) {
695 throw std::invalid_argument(
"Invalid history magic.");
697 throw_if_io_fail(unmarshall(in, version));
700 std::ranges::for_each(d->pools_ | std::views::take(2),
701 [&in](
auto &pool) { pool.load(in); });
704 std::ranges::for_each(d->pools_, [&in](
auto &pool) { pool.load(in); });
707 case historyBinaryFormatVersion:
711 readZSTDCompressed(in, [d](std::istream &compressIn) {
712 std::ranges::for_each(d->pools_, [&compressIn](
auto &pool) {
713 pool.load(compressIn);
718 throw std::invalid_argument(
"Invalid history version.");
722 void HistoryBigram::loadText(std::istream &in) {
724 std::ranges::for_each(d->pools_, [&in](
auto &pool) { pool.loadText(in); });
727 void HistoryBigram::save(std::ostream &out) {
729 throw_if_io_fail(marshall(out, historyBinaryFormatMagic));
730 throw_if_io_fail(marshall(out, historyBinaryFormatVersion));
732 writeZSTDCompressed(out, [d](std::ostream &compressOut) {
733 std::ranges::for_each(
734 d->pools_, [&compressOut](
auto &pool) { pool.save(compressOut); });
738 void HistoryBigram::dump(std::ostream &out) {
740 std::ranges::for_each(d->pools_,
741 [&out](
const auto &pool) { pool.dump(out); });
744 void HistoryBigram::clear() {
746 std::ranges::for_each(d->pools_, std::mem_fn(&HistoryBigramPool::clear));
749 void HistoryBigram::forget(std::string_view word) { forget(word,
""); }
751 void HistoryBigram::forget(std::string_view word, std::string_view code) {
753 std::ranges::for_each(
754 d->pools_, [word, code](
auto &pool) { pool.forget(word, code); });
758 const std::vector<std::string> &sentence,
759 size_t maxSize)
const {
761 if (maxSize > 0 && words.size() >= maxSize) {
765 if (!sentence.empty()) {
766 lookup = sentence.back();
770 lookup += bigramSeparator;
771 std::ranges::for_each(
772 d->pools_, [&words, &lookup, maxSize](
const HistoryBigramPool &pool) {
773 pool.fillPredict(words, lookup, maxSize);
777 bool HistoryBigram::containsBigram(std::string_view prev,
778 std::string_view cur)
const {
780 return std::ranges::any_of(
781 d->pools_, [&prev, &cur](
const HistoryBigramPool &pool) {
782 return pool.bigramFreq({prev,
""}, {cur,
""}) > 0;
788 return d->unigramFreq(word);
792 WordWithCodeView cur)
const {
794 return d->bigramFreq(prev, cur);
800 for (
const auto &pool : d->pools_) {
801 freq += pool.unigramFreq(word);
807 WordWithCodeView cur)
const {
810 for (
const auto &pool : d->pools_) {
811 freq += pool.bigramFreq(prev, cur);
816 float HistoryBigram::score(
const WordNode *prev,
const WordNode *cur)
const {
817 return scoreWithCode(prev, cur,
nullptr);
820 float HistoryBigram::scoreWithCode(
822 const ValidationCodeExtractor &extractor)
const {
823 return scoreWithCode(
824 {prev ? prev->word() :
"", extractor && prev ? extractor(prev) :
""},
825 {cur ? cur->word() :
"", extractor && cur ? extractor(cur) :
""});
828 void HistoryBigram::addWithContext(
const std::vector<WordWithCode> &context,
829 std::vector<WordWithCode> newSentence) {
831 if (context.empty() ||
832 !d->pools_[0].maybeAppendToLatestSentence(context, newSentence)) {
833 addWithCode(newSentence);
int32_t rawUnigramFrequency(WordWithCodeView word) const
Query the raw frequency of the unigram.
float bigramFrequency(WordWithCodeView prev, WordWithCodeView cur) const
Query the weighted frequency of the bigram.
Provide a DATrie implementation.
float unigramFrequency(WordWithCodeView word) const
Query the weighted frequency of the unigram.
void fillPredict(std::unordered_set< std::string > &words, const std::vector< std::string > &sentence, size_t maxSize) const
Fill the prediction based on current sentence.
int32_t rawBigramFrequency(WordWithCodeView prev, WordWithCodeView cur) const
Query the raw frequency of the bigram.