diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index f54bb151a..0675de6fa 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -39,7 +39,7 @@ bool LanguageModelDictContent::runGC( } int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds, - const int wordId) const { + const int wordId, const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); int maxLevel = 0; @@ -58,14 +58,15 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI if (!result.mIsValid) { continue; } - const int probability = - ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability(); + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); if (mHasHistoricalInfo) { - return std::min( - probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */), - MAX_PROBABILITY); + const int probability = ForgettingCurveUtils::decodeProbability( + probabilityEntry.getHistoricalInfo(), headerPolicy) + + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */); + return std::min(probability, MAX_PROBABILITY); } else { - return probability; + return probabilityEntry.getProbability(); } } // Cannot find the word. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 4e0b47036..a793af4be 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -128,7 +128,8 @@ class LanguageModelDictContent { const LanguageModelDictContent *const originalContent, int *const outNgramCount); - int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const; + int getWordProbability(const WordIdArrayView prevWordIds, const int wordId, + const HeaderPolicy *const headerPolicy) const; ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index e624bf338..1336a6229 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -121,9 +121,10 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); // TODO: Support n-gram. - return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability( - prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(), - ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0); + const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability( + prevWordIds.limit(1 /* maxSize */), wordId, mHeaderPolicy); + return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), + probability == 0); } int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp index 7608b45c2..c5849d054 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp @@ -107,13 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) { languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry); languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId, &bigramProbabilityEntry); - EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId)); + EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId, + nullptr /* headerPolicy */)); const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability); languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), prevWordIds[1], &probabilityEntry); languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId, &trigramProbabilityEntry); - EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId)); + EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId, + nullptr /* headerPolicy */)); } } // namespace