diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 88982e540..df3daa816 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -354,7 +354,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } bool addedNewBigram = false; const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); - if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos), + if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos, bigramProperty, &addedNewBigram)) { if (addedNewBigram) { mBigramCount++; @@ -396,7 +396,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor } const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); if (mUpdatingHelper.removeNgramEntry( - PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) { + PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) { mBigramCount--; return true; } else { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index d5749e9eb..f54bb151a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -38,6 +38,40 @@ bool LanguageModelDictContent::runGC( 0 /* nextLevelBitmapEntryIndex */, outNgramCount); } +int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds, + const int wordId) const { + int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); + int maxLevel = 0; + for (size_t i = 0; i < prevWordIds.size(); ++i) { + const int nextBitmapEntryIndex = + mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex; + if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) { + break; + } + maxLevel = i + 1; + bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; + } + + for (int i = maxLevel; i >= 0; --i) { + const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); + if (!result.mIsValid) { + continue; + } + const int probability = + ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability(); + if (mHasHistoricalInfo) { + return std::min( + probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */), + MAX_PROBABILITY); + } else { + return probability; + } + } + // Cannot find the word. + return NOT_A_PROBABILITY; +} + ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( const WordIdArrayView prevWordIds, const int wordId) const { const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index aa612e35a..4e0b47036 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -128,6 +128,8 @@ class LanguageModelDictContent { const LanguageModelDictContent *const originalContent, int *const outNgramCount); + int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const; + ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 6de3e5a81..308c35585 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -115,24 +115,12 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId, MultiBigramMap *const multiBigramMap) const { - // TODO: Quit using MultiBigramMap. if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } - const int ptNodePos = - mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); - const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); - if (multiBigramMap) { - return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds, - wordId, ptNodeParams.getProbability()); - } - if (prevWordIds) { - const int probability = getProbabilityOfWord(prevWordIds, wordId); - if (probability != NOT_A_PROBABILITY) { - return probability; - } - } - return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); + // TODO: Support n-gram. + return mBuffers->getLanguageModelDictContent()->getWordProbability( + WordIdArrayView::singleElementView(prevWordIds), wordId); } int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, @@ -166,7 +154,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, // TODO: Support n-gram. const ProbabilityEntry probabilityEntry = mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry( - IntArrayView::fromObject(prevWordIds), wordId); + IntArrayView::singleElementView(prevWordIds), wordId); if (!probabilityEntry.isValid()) { return NOT_A_PROBABILITY; } @@ -194,7 +182,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds, // TODO: Support n-gram. const auto languageModelDictContent = mBuffers->getLanguageModelDictContent(); for (const auto entry : languageModelDictContent->getProbabilityEntries( - WordIdArrayView::fromObject(prevWordIds))) { + WordIdArrayView::singleElementView(prevWordIds))) { const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); const int probability = probabilityEntry.hasHistoricalInfo() ? ForgettingCurveUtils::decodeProbability( @@ -511,7 +499,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( // Fetch bigram information. // TODO: Support n-gram. std::vector bigrams; - const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId); + const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId); int bigramWord1CodePoints[MAX_WORD_LENGTH]; for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries( prevWordIds)) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h index 9910777b8..313eb6b64 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h @@ -48,6 +48,11 @@ class ForgettingCurveUtils { static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount, const int bigramCount, const HeaderPolicy *const headerPolicy); + // TODO: Improve probability computation method and remove this. + static int getProbabilityBiasForNgram(const int n) { + return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE; + } + AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) { return static_cast(static_cast(maxUnigramCount) * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT); diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h index c9c3b21d4..08256bdef 100644 --- a/native/jni/src/utils/int_array_view.h +++ b/native/jni/src/utils/int_array_view.h @@ -61,9 +61,9 @@ class IntArrayView { return IntArrayView(array, N); } - // Returns a view that points one int object. Does not take ownership of the given object. - AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) { - return IntArrayView(object, 1); + // Returns a view that points one int object. + AK_FORCE_INLINE static IntArrayView singleElementView(const int *const ptr) { + return IntArrayView(ptr, 1); } AK_FORCE_INLINE int operator[](const size_t index) const { diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp index ca8d56f27..e6f0353e3 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp @@ -26,28 +26,28 @@ namespace latinime { namespace { TEST(LanguageModelDictContentTest, TestUnigramProbability) { - LanguageModelDictContent LanguageModelDictContent(false /* useHistoricalInfo */); + LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */); const int flag = 0xFF; const int probability = 10; const int wordId = 100; const ProbabilityEntry probabilityEntry(flag, probability); - LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); + languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); const ProbabilityEntry entry = - LanguageModelDictContent.getProbabilityEntry(wordId); + languageModelDictContent.getProbabilityEntry(wordId); EXPECT_EQ(flag, entry.getFlags()); EXPECT_EQ(probability, entry.getProbability()); // Remove - EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId)); - EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid()); - EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId)); - EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry)); - EXPECT_TRUE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid()); + EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId)); + EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid()); + EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId)); + EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry)); + EXPECT_TRUE(languageModelDictContent.getProbabilityEntry(wordId).isValid()); } TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) { - LanguageModelDictContent LanguageModelDictContent(true /* useHistoricalInfo */); + LanguageModelDictContent languageModelDictContent(true /* useHistoricalInfo */); const int flag = 0xF0; const int timestamp = 0x3FFFFFFF; @@ -56,19 +56,19 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) { const int wordId = 100; const HistoricalInfo historicalInfo(timestamp, level, count); const ProbabilityEntry probabilityEntry(flag, &historicalInfo); - LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); - const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId); + languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); + const ProbabilityEntry entry = languageModelDictContent.getProbabilityEntry(wordId); EXPECT_EQ(flag, entry.getFlags()); EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp()); EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel()); EXPECT_EQ(count, entry.getHistoricalInfo()->getCount()); // Remove - EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId)); - EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid()); - EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId)); - EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry)); - EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId)); + EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId)); + EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid()); + EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId)); + EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry)); + EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId)); } TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) { @@ -89,5 +89,31 @@ TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) { EXPECT_TRUE(wordIdSet.empty()); } +TEST(LanguageModelDictContentTest, TestGetWordProbability) { + LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */); + + const int flag = 0xFF; + const int probability = 10; + const int bigramProbability = 20; + const int trigramProbability = 30; + const int wordId = 100; + const int prevWordIdArray[] = { 1, 2 }; + const WordIdArrayView prevWordIds = WordIdArrayView::fromFixedSizeArray(prevWordIdArray); + + const ProbabilityEntry probabilityEntry(flag, probability); + languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); + const ProbabilityEntry bigramProbabilityEntry(flag, bigramProbability); + languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry); + languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId, + &bigramProbabilityEntry); + EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId)); + const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability); + languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), + prevWordIds[1], &probabilityEntry); + languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId, + &trigramProbabilityEntry); + EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId)); +} + } // namespace } // namespace latinime diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp index 161df2f43..93bad5822 100644 --- a/native/jni/tests/utils/int_array_view_test.cpp +++ b/native/jni/tests/utils/int_array_view_test.cpp @@ -52,7 +52,7 @@ TEST(IntArrayViewTest, TestConstructFromArray) { TEST(IntArrayViewTest, TestConstructFromObject) { const int object = 10; - const auto intArrayView = IntArrayView::fromObject(&object); + const auto intArrayView = IntArrayView::singleElementView(&object); EXPECT_EQ(1u, intArrayView.size()); EXPECT_EQ(object, intArrayView[0]); }