diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index bbcea2ee0..a66cfef76 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -16,6 +16,8 @@
 
 #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
 
+#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
+
 namespace latinime {
 
 bool LanguageModelDictContent::save(FILE *const file) const {
@@ -118,4 +120,67 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
     return bitmapEntryIndex;
 }
 
+/**
+ * Recursively refreshes the probability entries stored at and below one trie map level.
+ *
+ * When mHasHistoricalInfo is set, every entry that does not represent
+ * beginning-of-sentence is rewritten with the HistoricalInfo returned by
+ * ForgettingCurveUtils::createHistoricalInfoToSave(), or removed when
+ * ForgettingCurveUtils::needsToKeep() rejects it. A removed entry is neither counted nor
+ * descended into.
+ *
+ * @param bitmapEntryIndex bitmap entry index of the trie map level to visit.
+ * @param level depth of that level; also the index used into outEntryCounts.
+ * @param headerPolicy policy forwarded to ForgettingCurveUtils.
+ * @param outEntryCounts per-level counters. They are only incremented here, once for each
+ *         surviving entry that does not represent beginning-of-sentence.
+ * @return false when an entry is found at a level above MAX_PREV_WORD_COUNT_FOR_N_GRAM or when
+ *         a trie map put/remove fails (at this level or a nested one); true otherwise.
+ */
+bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
+        const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
+    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+        // Checked per entry, so an empty level is never reported as invalid. It also keeps
+        // level within the MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1 counters that runGC() passes in.
+        if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
+            AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
+                    level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
+            return false;
+        }
+        const ProbabilityEntry probabilityEntry =
+                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
+            // NOTE(review): put()/remove() below modify the level that is being iterated over;
+            // confirm that TrieMap iteration tolerates this.
+            const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
+                    probabilityEntry.getHistoricalInfo(), headerPolicy);
+            if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
+                // Keep the entry, storing the refreshed historical info.
+                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
+                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
+                        bitmapEntryIndex)) {
+                    return false;
+                }
+            } else {
+                // Drop the entry; it is neither counted nor descended into.
+                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
+                    return false;
+                }
+                continue;
+            }
+        }
+        if (!probabilityEntry.representsBeginningOfSentence()) {
+            outEntryCounts[level] += 1;
+        }
+        if (!entry.hasNextLevelMap()) {
+            continue;
+        }
+        if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
+                headerPolicy, outEntryCounts)) {
+            return false;
+        }
+    }
+    return true;
+}
+
 } // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index bd07f2f62..31ee2fe24 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -29,6 +29,8 @@
 
 namespace latinime {
 
+class HeaderPolicy;
+
 /**
  * Class representing language model.
  *
@@ -73,6 +75,26 @@ class LanguageModelDictContent {
 
     bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
 
+    /**
+     * Walks every level of the trie map from the root, refreshing or removing entries that
+     * carry historical info and counting the entries that remain.
+     *
+     * @param headerPolicy policy forwarded to ForgettingCurveUtils.
+     * @param outEntryCounts array of at least MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1 ints. It is
+     *         zero-filled first, then element [level] receives the number of entries kept at
+     *         that level, excluding beginning-of-sentence entries.
+     * @return false if any level could not be updated.
+     */
+    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
+            int *const outEntryCounts) {
+        // The inner traversal only increments the counters, so reset them here; callers such
+        // as runGC() pass an uninitialized buffer.
+        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+            outEntryCounts[i] = 0;
+        }
+        return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
+                headerPolicy, outEntryCounts);
+    }
+
  private:
     DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
 
@@ -84,6 +106,9 @@ class LanguageModelDictContent {
             int *const outNgramCount);
     int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
     int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
+    // Recursive worker behind updateAllProbabilityEntries(); documented at its definition.
+    bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
+            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
 };
 } // namespace latinime
 #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index fb6840ba6..b7c31bf75 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
     const ProbabilityEntry originalProbabilityEntry =
             mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
                     toBeUpdatedPtNodeParams->getTerminalId());
-    if (originalProbabilityEntry.hasHistoricalInfo()) {
-        const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
-                originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
-        const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
-                &historicalInfo);
-        if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
-                toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
-            AKLOGE("Cannot write updated probability entry. terminalId: %d",
-                    toBeUpdatedPtNodeParams->getTerminalId());
-            return false;
-        }
-        const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
-        if (!isValid) {
-            if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
-                AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
-                return false;
-            }
-        }
-        *outNeedsToKeepPtNode = isValid;
-    } else {
-        // No need to update probability.
+    if (originalProbabilityEntry.isValid()) {
         *outNeedsToKeepPtNode = true;
+        return true;
     }
+    if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
+        AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
+        return false;
+    }
+    *outNeedsToKeepPtNode = false;
     return true;
 }
 
@@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
             isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
 }
 
+// TODO: Move probability handling code to LanguageModelDictContent.
 const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
         const ProbabilityEntry *const originalProbabilityEntry,
         const ProbabilityEntry *const probabilityEntry) const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 4220312e0..35bc44b8f 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -85,6 +85,13 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
             mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
             &shortcutPolicy);
 
+    // updateAllProbabilityEntries() zero-fills this table before counting into it.
+    int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+    if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
+            entryCountTable)) {
+        AKLOGE("Failed to update probabilities in language model dict content.");
+        return false;
+    }
     DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
     readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
     DynamicPtGcEventListeners
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
index 6d91790b2..c2aeac211 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
@@ -84,6 +84,11 @@ class TrieMap {
             return mValue;
         }
 
+        // Bitmap entry index to use when iterating the level nested under this entry.
+        AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const {
+            return mNextLevelBitmapEntryIndex;
+        }
+
      private:
         const TrieMap *const mTrieMap;
         const int mKey;