libime
historybigram.cpp
1 /*
2  * SPDX-FileCopyrightText: 2017-2017 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 #include "historybigram.h"
7 #include <algorithm>
8 #include <array>
9 #include <cassert>
10 #include <cmath>
11 #include <cstddef>
12 #include <cstdint>
13 #include <functional>
14 #include <istream>
15 #include <iterator>
16 #include <list>
17 #include <memory>
18 #include <ostream>
19 #include <ranges>
20 #include <stdexcept>
21 #include <string>
22 #include <string_view>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 #include <fcitx-utils/macros.h>
27 #include <fcitx-utils/stringutils.h>
28 #include "constants.h"
29 #include "datrie.h"
30 #include "lattice.h"
31 #include "utils_p.h"
32 #include "zstdfilter.h"
33 
34 namespace libime {
35 
36 namespace {
37 
38 using WordWithCode = HistoryBigram::WordWithCode;
39 using WordWithCodeView = HistoryBigram::WordWithCodeView;
40 
41 constexpr uint32_t historyBinaryFormatMagic = 0x000fc315;
42 constexpr uint32_t historyBinaryFormatVersion = 0x4;
43 constexpr char bigramSeparator = '\x01';
44 constexpr char wordCodeSeparator = '\x02';
45 
46 std::string wordWithCodeToString(WordWithCodeView wordAndCode) {
47  std::string s{std::get<0>(wordAndCode)};
48  if (s.empty()) {
49  return s;
50  }
51  auto code = std::get<1>(wordAndCode);
52  if (!code.empty()) {
53  s += wordCodeSeparator;
54  s += code;
55  }
56  return s;
57 }
58 
59 std::string bigramWordWithCodeToString(WordWithCodeView prev,
60  WordWithCodeView cur) {
61  std::string s;
62  const auto &word1 = std::get<0>(prev);
63  const auto &word2 = std::get<0>(cur);
64  const auto &code1 = std::get<1>(prev);
65  const auto &code2 = std::get<1>(cur);
66  const bool hasCode = !code1.empty() || !code2.empty();
67  s.reserve(word1.size() + word2.size() + code1.size() + code2.size() + 1 +
68  (hasCode ? 2 : 0));
69  s += word1;
70  s += bigramSeparator;
71  s += word2;
72 
73  if (hasCode) {
74  s += wordCodeSeparator;
75  s += code1;
76  s += bigramSeparator;
77  s += code2;
78  }
79  return s;
80 }
81 
82 struct WeightedTrie {
83  using TrieType = DATrie<int32_t>;
84 
85 public:
86  WeightedTrie() = default;
87 
88  void clear() { trie_.clear(); }
89 
90  const TrieType &trie() const { return trie_; }
91 
92  int32_t weightedSize() const { return weightedSize_; }
93 
94  int32_t freq(WordWithCodeView wordAndCode) const {
95  // If query with code, the match will be {word, ""} + {word, code}.
96  // If query without code, the match will be {word, ""} + {word,
97  // separator}.
98  TrieType::position_type pos = 0;
99  auto result = 0;
100  auto v = trie_.traverse(wordAndCode.first, pos);
101  if (TrieType::isValid(v)) {
102  result += v;
103  } else if (TrieType::isNoPath(v)) {
104  return 0;
105  }
106  const char separator[] = {wordCodeSeparator, '\0'};
107  v = trie_.traverse(separator, pos);
108  if (!TrieType::isNoPath(v)) {
109  if (!wordAndCode.second.empty()) {
110  v = trie_.traverse(wordAndCode.second, pos);
111  if (TrieType::isValid(v)) {
112  result += v;
113  }
114  } else {
115  trie_.foreach(
116  [&result](TrieType::value_type value, size_t len,
117  TrieType::position_type /*pos*/) {
118  if (len == 0) {
119  return true;
120  }
121  result += value;
122  return true;
123  },
124  pos);
125  }
126  }
127  return result;
128  }
129 
130  int32_t freq(WordWithCodeView prev, WordWithCodeView next) const {
131  const char bigramSeparatorString[] = {bigramSeparator, '\0'};
132  TrieType::position_type pos = 0;
133  auto result = 0;
134  auto v =
135  trie_.traverse(pos, prev.first, bigramSeparatorString, next.first);
136  if (TrieType::isValid(v)) {
137  result += v;
138  } else if (TrieType::isNoPath(v)) {
139  return 0;
140  }
141  const char separator[] = {wordCodeSeparator, '\0'};
142  v = trie_.traverse(separator, pos);
143  if (!TrieType::isNoPath(v)) {
144  if (!prev.second.empty() && !next.second.empty()) {
145  v = trie_.traverse(pos, prev.second, bigramSeparatorString,
146  next.second);
147  if (TrieType::isValid(v)) {
148  result += v;
149  }
150  } else {
151  if (!prev.second.empty()) {
152  v = trie_.traverse(pos, prev.second, bigramSeparatorString);
153  if (TrieType::isNoPath(v)) {
154  return result;
155  }
156  }
157  if (!next.second.empty()) {
158  trie_.foreach(
159  [this, &result, &next](TrieType::value_type value,
160  size_t len,
161  TrieType::position_type pos) {
162  if (len < next.second.size()) {
163  return true;
164  }
165  std::string codeInTrie;
166  trie().suffix(codeInTrie, len, pos);
167  if (codeInTrie.ends_with(next.second)) {
168  result += value;
169  }
170  return true;
171  },
172  pos);
173  } else {
174  trie_.foreach(
175  [&result](TrieType::value_type value, size_t /*len*/,
176  TrieType::position_type /*pos*/) {
177  result += value;
178  return true;
179  },
180  pos);
181  }
182  }
183  }
184  return result;
185  }
186 
187  void incFreqImpl(const std::string &s, int32_t delta) {
188  trie_.update(s.data(), s.size(),
189  [delta](int32_t v) { return v + delta; });
190  weightedSize_ += delta;
191  }
192 
193  void incFreq(WordWithCodeView wordAndCode, int32_t delta) {
194  incFreqImpl(wordWithCodeToString(wordAndCode), delta);
195  }
196 
197  void incFreq(WordWithCodeView prev, WordWithCodeView next, int32_t delta) {
198  incFreqImpl(bigramWordWithCodeToString(prev, next), delta);
199  }
200 
201  void decFreqImpl(const std::string &s, int32_t delta) {
202  auto v = trie_.exactMatchSearch(s.data(), s.size());
203  if (TrieType::isNoValue(v)) {
204  return;
205  }
206  if (v <= delta) {
207  trie_.erase(s.data(), s.size());
208  decWeightedSize(v);
209  } else {
210  v -= delta;
211  trie_.set(s.data(), s.size(), v);
212  decWeightedSize(delta);
213  }
214  }
215 
216  void decFreq(WordWithCodeView wordAndCode, int32_t delta) {
217  decFreqImpl(wordWithCodeToString(wordAndCode), delta);
218  }
219 
220  void decFreq(WordWithCodeView prev, WordWithCodeView next, int32_t delta) {
221  decFreqImpl(bigramWordWithCodeToString(prev, next), delta);
222  }
223 
224  void fillPredict(std::unordered_set<std::string> &words,
225  std::string_view word, size_t maxSize) const {
226  trie_.foreach(word,
227  [this, &words, maxSize](TrieType::value_type, size_t len,
228  TrieType::position_type pos) {
229  std::string buf;
230  trie().suffix(buf, len, pos);
231  auto separatorPos = buf.find(wordCodeSeparator);
232  if (separatorPos != std::string::npos) {
233  buf.erase(separatorPos);
234  }
235  // Skip special word.
236  if (buf == "<s>" || buf == "</s>") {
237  return true;
238  }
239  words.emplace(std::move(buf));
240 
241  return maxSize <= 0 || words.size() < maxSize;
242  });
243  }
244 
245 private:
246  void decWeightedSize(int32_t v) {
247  weightedSize_ -= v;
248  weightedSize_ = std::max(weightedSize_, 0);
249  }
250 
251  int32_t weightedSize_ = 0;
252  TrieType trie_;
253 };
254 
255 class HistoryBigramPool {
256 public:
257  HistoryBigramPool(size_t maxSize) : maxSize_(maxSize) {}
258 
259  void load(std::istream &in) {
260  clear();
261  uint32_t count = 0;
262  throw_if_io_fail(unmarshall(in, count));
263  while (count--) {
264  uint32_t size = 0;
265  throw_if_io_fail(unmarshall(in, size));
266  std::vector<WordWithCode> sentence;
267  while (size--) {
268  std::string buffer;
269  throw_if_io_fail(unmarshallString(in, buffer));
270  std::string_view bufferView{buffer};
271  size_t separatorPos = bufferView.find(wordCodeSeparator);
272  if (separatorPos != std::string_view::npos) {
273  sentence.emplace_back(
274  std::string(bufferView.substr(0, separatorPos)),
275  std::string(bufferView.substr(separatorPos + 1)));
276  } else {
277  sentence.emplace_back(std::move(buffer), "");
278  }
279  }
280  add(sentence);
281  }
282  }
283 
284  void loadText(std::istream &in) {
285  clear();
286  std::string buf;
287  std::vector<std::string> lines;
288  while (std::getline(in, buf)) {
289  lines.emplace_back(buf);
290  if (lines.size() >= maxSize_) {
291  break;
292  }
293  }
294  for (auto &line : lines | std::views::reverse) {
295  std::string_view lineView{line};
296  std::vector<std::string> tokens;
297  bool withCode = false;
298  while (!lineView.empty()) {
299  std::string token;
300  auto consumed = fcitx::stringutils::consumeMaybeEscapedValue(
301  lineView, FCITX_WHITESPACE, &token);
302  if (!consumed.empty()) {
303  tokens.push_back(std::move(token));
304  }
305  if (tokens.size() == 1 && !lineView.empty() &&
306  lineView.front() == '\t') {
307  withCode = true;
308  }
309  }
310 
311  if (withCode) {
312  if (tokens.size() % 2 != 0) {
313  continue;
314  }
315  add(std::views::iota(static_cast<size_t>(0),
316  tokens.size() / 2) |
317  std::views::transform([&tokens](size_t i) {
318  return WordWithCode{tokens[i * 2], tokens[(i * 2) + 1]};
319  }));
320 
321  } else {
322  add(tokens |
323  std::views::transform([](const auto &word) -> WordWithCode {
324  std::vector<std::string> wordWithMaybeCode =
325  fcitx::stringutils::split(
326  word, "\t",
327  fcitx::stringutils::SplitBehavior::KeepEmpty);
328  if (wordWithMaybeCode.size() == 2) {
329  return WordWithCode{wordWithMaybeCode[0],
330  wordWithMaybeCode[1]};
331  }
332  return WordWithCode{word, ""};
333  }));
334  }
335  }
336  }
337 
338  void save(std::ostream &out) {
339  uint32_t count = recent_.size();
340  throw_if_io_fail(marshall(out, count));
341  // When we do save, we need to reverse the history order.
342  // Because loading the history is done by call "add", which basically
343  // expect the history from old to new.
344  for (auto &sentence : recent_ | std::views::reverse) {
345  uint32_t size = sentence.size();
346  throw_if_io_fail(marshall(out, size));
347  for (const auto &s : sentence) {
348  throw_if_io_fail(marshallString(out, wordWithCodeToString(s)));
349  }
350  }
351  }
352 
353  void dump(std::ostream &out) const {
354  for (const auto &sentence : recent_) {
355  bool first = true;
356  bool hasCode = std::ranges::any_of(sentence, [](const auto &item) {
357  return !std::get<1>(item).empty();
358  });
359  for (const auto &s : sentence) {
360  if (first) {
361  first = false;
362  } else {
363  out << " ";
364  }
365  out << fcitx::stringutils::escapeForValue(std::get<0>(s));
366  if (hasCode) {
367  out << "\t"
368  << fcitx::stringutils::escapeForValue(std::get<1>(s));
369  }
370  }
371  out << '\n';
372  }
373  }
374 
375  void clear() {
376  recent_.clear();
377  unigram_.clear();
378  bigram_.clear();
379  size_ = 0;
380  }
381 
382  template <typename R>
383  std::list<std::vector<WordWithCode>> add(const R &sentence) {
384  std::list<std::vector<WordWithCode>> popedSentence;
385  if (sentence.empty()) {
386  return popedSentence;
387  }
388  // Validate data.
389  if (std::ranges::any_of(sentence, [](const auto &item) {
390  const auto &[word, code] = item;
391  return word.find('\0') != std::string::npos;
392  })) {
393  return popedSentence;
394  }
395  while (recent_.size() >= maxSize_) {
396  remove(recent_.back());
397  popedSentence.splice(popedSentence.end(), recent_,
398  std::prev(recent_.end()));
399  }
400 
401  std::vector<WordWithCode> newSentence;
402  auto delta = 1;
403  for (auto iter = sentence.begin(), end = sentence.end(); iter != end;
404  iter++) {
405  unigram_.incFreq(*iter, delta);
406  auto next = std::ranges::next(iter);
407  if (next != end) {
408  incBigram(*iter, *next, delta);
409  }
410  newSentence.push_back(*iter);
411  }
412  recent_.push_front(std::move(newSentence));
413  unigram_.incFreq({"<s>", ""}, delta);
414  unigram_.incFreq({"</s>", ""}, delta);
415  incBigram({"<s>", ""}, sentence.front(), delta);
416  incBigram(sentence.back(), {"</s>", ""}, delta);
417 
418  return popedSentence;
419  }
420 
421  int32_t unigramFreq(WordWithCodeView s) const { return unigram_.freq(s); }
422 
423  int32_t bigramFreq(WordWithCodeView s, WordWithCodeView s2) const {
424  return bigram_.freq(s, s2);
425  }
426 
427  bool isUnknown(WordWithCodeView word) const {
428  return unigramFreq(word) == 0;
429  }
430 
431  size_t maxSize() const { return maxSize_; }
432 
433  size_t realSize() const { return recent_.size(); }
434 
435  void forget(std::string_view word, std::string_view code) {
436  auto iter = recent_.begin();
437  while (iter != recent_.end()) {
438  if (std::find_if(
439  iter->begin(), iter->end(), [word, code](const auto &item) {
440  const auto &[w, c] = item;
441  return w == word && (code.empty() || c == code);
442  }) != iter->end()) {
443  remove(*iter);
444  iter = recent_.erase(iter);
445  } else {
446  ++iter;
447  }
448  }
449  }
450 
451  void fillPredict(std::unordered_set<std::string> &words,
452  std::string_view word, size_t maxSize = 0) const {
453  bigram_.fillPredict(words, word, maxSize);
454  }
455 
456  bool maybeAppendToLatestSentence(const std::vector<WordWithCode> &context,
457  std::vector<WordWithCode> &newSentence) {
458  if (recent_.empty() || newSentence.empty()) {
459  return false;
460  }
461  auto &latestSentence = recent_.front();
462  if (latestSentence.size() < context.size() ||
463  !std::ranges::equal(
464  context,
465  std::views::drop(latestSentence,
466  latestSentence.size() - context.size()))) {
467  return false;
468  }
469 
470  const int delta = 1;
471  decBigram(latestSentence.back(), {"</s>", ""}, delta);
472  for (auto &item : newSentence) {
473  unigram_.incFreq(item, delta);
474  incBigram(latestSentence.back(), item, delta);
475  latestSentence.push_back(std::move(item));
476  }
477  incBigram(latestSentence.back(), {"</s>", ""}, delta);
478 
479  return true;
480  }
481 
482 private:
483  template <typename R>
484  void remove(const R &sentence) {
485  const int delta = 1;
486  for (auto iter = sentence.begin(), end = sentence.end(); iter != end;
487  iter++) {
488  unigram_.decFreq(*iter, delta);
489  auto next = std::next(iter);
490  if (next != end) {
491  decBigram(*iter, *next, delta);
492  }
493  }
494  decBigram({"<s>", ""}, sentence.front(), delta);
495  decBigram(sentence.back(), {"</s>", ""}, delta);
496  }
497 
498  void decBigram(WordWithCodeView s1, WordWithCodeView s2, int32_t delta) {
499  bigram_.decFreq(s1, s2, delta);
500  }
501 
502  void incBigram(WordWithCodeView s1, WordWithCodeView s2, int delta) {
503  bigram_.incFreq(s1, s2, delta);
504  }
505 
506  const size_t maxSize_;
507 
508  // Used when maxSize_ != 0.
509  size_t size_ = 0;
510  std::list<std::vector<WordWithCode>> recent_;
511 
512  // Used for look up
513  WeightedTrie unigram_;
514  WeightedTrie bigram_;
515 };
516 
517 } // namespace
518 
519 // We define the frequency as following.
520 // (1 - p) the frequency belongs to first pool.
521 // p * (1 - p) Second pool
522 // p^2 * (1 - p) Third pool
523 // ...
524 // p^(n-1) n-th pool.
525 // In sum, it's (1-p) * p^(i - 1)
526 // And then we define alpha as p = 1 / (1 + alpha).
528 public:
529  void populateSentence(std::list<std::vector<WordWithCode>> popedSentence) {
530  for (size_t i = 1; !popedSentence.empty() && i < pools_.size(); i++) {
531  std::list<std::vector<WordWithCode>> nextSentences;
532  while (!popedSentence.empty()) {
533  auto newPopedSentence = pools_[i].add(popedSentence.front());
534  popedSentence.pop_front();
535  nextSentences.splice(nextSentences.end(), newPopedSentence);
536  }
537  popedSentence = std::move(nextSentences);
538  }
539  }
540 
541  float unigramFreq(WordWithCodeView word) const {
542  assert(pools_.size() == poolWeight_.size());
543  float freq = 0;
544  for (size_t i = 0; i < pools_.size(); i++) {
545  freq += pools_[i].unigramFreq(word) * poolWeight_[i];
546  }
547  return freq;
548  }
549 
550  float bigramFreq(WordWithCodeView prev, WordWithCodeView cur) const {
551  assert(pools_.size() == poolWeight_.size());
552  float freq = 0;
553  for (size_t i = 0; i < pools_.size(); i++) {
554  freq += pools_[i].bigramFreq(prev, cur) * poolWeight_[i];
555  }
556  return freq;
557  }
558 
559  float unigramSize() const {
560  float size = 0;
561  for (size_t i = 0; i < pools_.size(); i++) {
562  size += pools_[i].maxSize() * poolWeight_[i];
563  }
564  return size;
565  }
566 
567  // A log probabilty.
568  float unknown_ =
569  std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY);
570  bool useOnlyUnigram_ = false;
571  std::vector<HistoryBigramPool> pools_;
572  std::vector<float> poolWeight_;
573 };
574 
575 HistoryBigram::HistoryBigram()
576  : d_ptr(std::make_unique<HistoryBigramPrivate>()) {
577  FCITX_D();
578  const float p = 1.0 / (1 + HISTORY_BIGRAM_ALPHA_VALUE);
579  constexpr std::array<int, 3> poolSize = {128, 8192, 65536};
580  d->pools_.reserve(poolSize.size());
581  d->poolWeight_.reserve(poolSize.size());
582  for (auto size : poolSize) {
583  d->pools_.emplace_back(size);
584  float portion = 1.0F;
585  if (d->pools_.size() != poolSize.size()) {
586  portion *= 1 - p;
587  }
588  portion *= std::pow(p, d->pools_.size() - 1);
589  d->poolWeight_.push_back(portion / d->pools_.back().maxSize());
590  }
591  setUnknownPenalty(
592  std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY));
593 }
594 
595 FCITX_DEFINE_DEFAULT_DTOR_AND_MOVE(HistoryBigram)
596 
597 void HistoryBigram::setUnknownPenalty(float unknown) {
598  FCITX_D();
599  d->unknown_ = unknown;
600 }
601 
602 float HistoryBigram::unknownPenalty() const {
603  FCITX_D();
604  return d->unknown_;
605 }
606 
607 void HistoryBigram::setUseOnlyUnigram(bool useOnlyUnigram) {
608  FCITX_D();
609  d->useOnlyUnigram_ = useOnlyUnigram;
610 }
611 
612 bool HistoryBigram::useOnlyUnigram() const {
613  FCITX_D();
614  return d->useOnlyUnigram_;
615 }
616 
617 void HistoryBigram::add(const libime::SentenceResult &sentence) {
618  addWithCode(sentence, nullptr);
619 }
620 
621 void HistoryBigram::addWithCode(
622  const libime::SentenceResult &sentence,
623  const ValidationCodeExtractor &validationCodeExtractor) {
624  FCITX_D();
625  d->populateSentence(d->pools_[0].add(
626  sentence.sentence() |
627  std::views::transform(
628  [&validationCodeExtractor](const auto &item) -> WordWithCode {
629  return {item->word(), validationCodeExtractor
630  ? validationCodeExtractor(item)
631  : ""};
632  })));
633 }
634 
635 void HistoryBigram::add(const std::vector<std::string> &sentence) {
636  FCITX_D();
637  d->populateSentence(d->pools_[0].add(
638  sentence | std::views::transform([](const auto &word) -> WordWithCode {
639  return WordWithCode{word, ""};
640  })));
641 }
642 
643 void HistoryBigram::addWithCode(
644  const std::vector<WordWithCode> &sentenceWithValidationCode) {
645  FCITX_D();
646  d->populateSentence(d->pools_[0].add(sentenceWithValidationCode));
647 }
648 
649 bool HistoryBigram::isUnknown(std::string_view v) const {
650  FCITX_D();
651  return std::ranges::all_of(d->pools_, [v](const HistoryBigramPool &pool) {
652  return pool.isUnknown({v, ""});
653  });
654 }
655 
656 float HistoryBigram::score(std::string_view prev, std::string_view cur) const {
657  return scoreWithCode({prev, ""}, {cur, ""});
658 }
659 
660 float HistoryBigram::scoreWithCode(WordWithCodeView prev,
661  WordWithCodeView cur) const {
662  FCITX_D();
663  if (prev.first.empty()) {
664  prev.first = "<s>";
665  }
666  if (cur.first.empty()) {
667  cur.first = "<unk>";
668  }
669 
670  auto uf0 = d->unigramFreq(prev);
671  auto bf = d->bigramFreq(prev, cur);
672  auto uf1 = d->unigramFreq(cur);
673 
674  float bigramWeight = d->useOnlyUnigram_ ? 0.0F : 0.8F;
675  // add 0.5 to avoid div 0
676  float pr = 0.0F;
677  pr += bigramWeight * bf / (uf0 + (d->poolWeight_[0] / 2.0F));
678  pr += (1.0F - bigramWeight) * uf1 /
679  (d->unigramSize() + (d->poolWeight_[0] / 2.0F));
680 
681  pr = std::min<float>(pr, 1.0F);
682  if (pr == 0) {
683  return d->unknown_;
684  }
685 
686  return std::log10(pr);
687 }
688 
689 void HistoryBigram::load(std::istream &in) {
690  FCITX_D();
691  uint32_t magic = 0;
692  uint32_t version = 0;
693  throw_if_io_fail(unmarshall(in, magic));
694  if (magic != historyBinaryFormatMagic) {
695  throw std::invalid_argument("Invalid history magic.");
696  }
697  throw_if_io_fail(unmarshall(in, version));
698  switch (version) {
699  case 1:
700  std::ranges::for_each(d->pools_ | std::views::take(2),
701  [&in](auto &pool) { pool.load(in); });
702  break;
703  case 2:
704  std::ranges::for_each(d->pools_, [&in](auto &pool) { pool.load(in); });
705  break;
706  case 3:
707  case historyBinaryFormatVersion:
708  // For version 3 and version 4, the format is the same, but version 4
709  // contains additional code data, bump the version to it not backward
710  // compatible with version 3.
711  readZSTDCompressed(in, [d](std::istream &compressIn) {
712  std::ranges::for_each(d->pools_, [&compressIn](auto &pool) {
713  pool.load(compressIn);
714  });
715  });
716  break;
717  default:
718  throw std::invalid_argument("Invalid history version.");
719  }
720 }
721 
722 void HistoryBigram::loadText(std::istream &in) {
723  FCITX_D();
724  std::ranges::for_each(d->pools_, [&in](auto &pool) { pool.loadText(in); });
725 }
726 
727 void HistoryBigram::save(std::ostream &out) {
728  FCITX_D();
729  throw_if_io_fail(marshall(out, historyBinaryFormatMagic));
730  throw_if_io_fail(marshall(out, historyBinaryFormatVersion));
731 
732  writeZSTDCompressed(out, [d](std::ostream &compressOut) {
733  std::ranges::for_each(
734  d->pools_, [&compressOut](auto &pool) { pool.save(compressOut); });
735  });
736 }
737 
738 void HistoryBigram::dump(std::ostream &out) {
739  FCITX_D();
740  std::ranges::for_each(d->pools_,
741  [&out](const auto &pool) { pool.dump(out); });
742 }
743 
744 void HistoryBigram::clear() {
745  FCITX_D();
746  std::ranges::for_each(d->pools_, std::mem_fn(&HistoryBigramPool::clear));
747 }
748 
749 void HistoryBigram::forget(std::string_view word) { forget(word, ""); }
750 
751 void HistoryBigram::forget(std::string_view word, std::string_view code) {
752  FCITX_D();
753  std::ranges::for_each(
754  d->pools_, [word, code](auto &pool) { pool.forget(word, code); });
755 }
756 
757 void HistoryBigram::fillPredict(std::unordered_set<std::string> &words,
758  const std::vector<std::string> &sentence,
759  size_t maxSize) const {
760  FCITX_D();
761  if (maxSize > 0 && words.size() >= maxSize) {
762  return;
763  }
764  std::string lookup;
765  if (!sentence.empty()) {
766  lookup = sentence.back();
767  } else {
768  lookup = "<s>";
769  }
770  lookup += bigramSeparator;
771  std::ranges::for_each(
772  d->pools_, [&words, &lookup, maxSize](const HistoryBigramPool &pool) {
773  pool.fillPredict(words, lookup, maxSize);
774  });
775 }
776 
777 bool HistoryBigram::containsBigram(std::string_view prev,
778  std::string_view cur) const {
779  FCITX_D();
780  return std::ranges::any_of(
781  d->pools_, [&prev, &cur](const HistoryBigramPool &pool) {
782  return pool.bigramFreq({prev, ""}, {cur, ""}) > 0;
783  });
784 }
785 
786 float HistoryBigram::unigramFrequency(WordWithCodeView word) const {
787  FCITX_D();
788  return d->unigramFreq(word);
789 }
790 
791 float HistoryBigram::bigramFrequency(WordWithCodeView prev,
792  WordWithCodeView cur) const {
793  FCITX_D();
794  return d->bigramFreq(prev, cur);
795 }
796 
797 int32_t HistoryBigram::rawUnigramFrequency(WordWithCodeView word) const {
798  FCITX_D();
799  int32_t freq = 0;
800  for (const auto &pool : d->pools_) {
801  freq += pool.unigramFreq(word);
802  }
803  return freq;
804 }
805 
806 int32_t HistoryBigram::rawBigramFrequency(WordWithCodeView prev,
807  WordWithCodeView cur) const {
808  FCITX_D();
809  int32_t freq = 0;
810  for (const auto &pool : d->pools_) {
811  freq += pool.bigramFreq(prev, cur);
812  }
813  return freq;
814 }
815 
816 float HistoryBigram::score(const WordNode *prev, const WordNode *cur) const {
817  return scoreWithCode(prev, cur, nullptr);
818 }
819 
820 float HistoryBigram::scoreWithCode(
821  const WordNode *prev, const WordNode *cur,
822  const ValidationCodeExtractor &extractor) const {
823  return scoreWithCode(
824  {prev ? prev->word() : "", extractor && prev ? extractor(prev) : ""},
825  {cur ? cur->word() : "", extractor && cur ? extractor(cur) : ""});
826 }
827 
828 void HistoryBigram::addWithContext(const std::vector<WordWithCode> &context,
829  std::vector<WordWithCode> newSentence) {
830  FCITX_D();
831  if (context.empty() ||
832  !d->pools_[0].maybeAppendToLatestSentence(context, newSentence)) {
833  addWithCode(newSentence);
834  }
835 }
836 
837 } // namespace libime
int32_t rawUnigramFrequency(WordWithCodeView word) const
Query the raw frequency of the unigram.
float bigramFrequency(WordWithCodeView prev, WordWithCodeView cur) const
Query the weighted frequency of the bigram.
Provide a DATrie implementation.
float unigramFrequency(WordWithCodeView word) const
Query the weighted frequency of the unigram.
void fillPredict(std::unordered_set< std::string > &words, const std::vector< std::string > &sentence, size_t maxSize) const
Fill the prediction based on current sentence.
int32_t rawBigramFrequency(WordWithCodeView prev, WordWithCodeView cur) const
Query the raw frequency of the bigram.