diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index 61435c8d0..a7296a302 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -43,18 +43,18 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr const int wordId, const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); - int maxLevel = 0; + int maxPrevWordCount = 0; for (size_t i = 0; i < prevWordIds.size(); ++i) { const int nextBitmapEntryIndex = mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex; if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) { break; } - maxLevel = i + 1; + maxPrevWordCount = i + 1; bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; } - for (int i = maxLevel; i >= 0; --i) { + for (int i = maxPrevWordCount; i >= 0; --i) { const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); if (!result.mIsValid) { continue; @@ -69,9 +69,24 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr // The entry should not be treated as a valid entry. continue; } - probability = std::min(rawProbability - + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */), - MAX_PROBABILITY); + if (i == 0) { + // unigram + probability = rawProbability; + } else { + const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry( + prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]); + if (!prevWordProbabilityEntry.isValid()) { + continue; + } + if (prevWordProbabilityEntry.representsBeginningOfSentence()) { + probability = rawProbability; + } else { + const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability( + prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy); + probability = std::min(MAX_PROBABILITY - prevWordRawProbability + + rawProbability, MAX_PROBABILITY); + } + } } else { probability = probabilityEntry.getProbability(); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index fa1415633..f4d340f86 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -98,17 +98,17 @@ class ProbabilityEntry { } uint64_t encode(const bool hasHistoricalInfo) const { - uint64_t encodedEntry = static_cast(mFlags); + uint64_t encodedEntry = static_cast(mFlags); if (hasHistoricalInfo) { encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT)) - ^ static_cast(mHistoricalInfo.getTimestamp()); + | static_cast(mHistoricalInfo.getTimestamp()); encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT)) - ^ static_cast(mHistoricalInfo.getLevel()); + | static_cast(mHistoricalInfo.getLevel()); encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - ^ static_cast(mHistoricalInfo.getCount()); + | static_cast(mHistoricalInfo.getCount()); } else { encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT)) - ^ static_cast(mProbability); + | static_cast(mProbability); } return encodedEntry; }