diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index a66cfef76..ea2d24e67 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -16,6 +16,9 @@ #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" +#include +#include + #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" namespace latinime { @@ -68,6 +71,19 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView return mTrieMap.remove(wordId, bitmapEntryIndex); } +bool LanguageModelDictContent::truncateEntries(const int *const entryCounts, + const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) { + for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { + if (entryCounts[i] <= maxEntryCounts[i]) { + continue; + } + if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) { + return false; + } + } + return true; +} + bool LanguageModelDictContent::runGCInner( const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, const TrieMap::TrieMapRange trieMapRange, @@ -162,4 +178,87 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap return true; } +bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( + const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) { + std::vector prevWordIds; + std::vector entryInfoVector; + if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(), + &prevWordIds, &entryInfoVector)) { + return false; + } + if (static_cast(entryInfoVector.size()) <= maxEntryCount) { + return true; + } + const int entryCountToRemove = 
static_cast(entryInfoVector.size()) - maxEntryCount; + std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove, + entryInfoVector.end(), + EntryInfoToTurncate::Comparator()); + for (int i = 0; i < entryCountToRemove; ++i) { + const EntryInfoToTurncate &entryInfo = entryInfoVector[i]; + if (!removeNgramProbabilityEntry( + WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) { + return false; + } + } + return true; +} + +bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy, + const int targetLevel, const int bitmapEntryIndex, std::vector *const prevWordIds, + std::vector *const outEntryInfo) const { + const int currentLevel = prevWordIds->size(); + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + if (currentLevel < targetLevel) { + if (!entry.hasNextLevelMap()) { + continue; + } + prevWordIds->push_back(entry.key()); + if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(), + prevWordIds, outEntryInfo)) { + return false; + } + prevWordIds->pop_back(); + continue; + } + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + const int probability = (mHasHistoricalInfo) ? 
+ ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), + headerPolicy) : probabilityEntry.getProbability(); + outEntryInfo->emplace_back(probability, + probabilityEntry.getHistoricalInfo()->getTimeStamp(), + entry.key(), targetLevel, prevWordIds->data()); + } + return true; +} + +bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( + const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const { + if (left.mProbability != right.mProbability) { + return left.mProbability < right.mProbability; + } + if (left.mTimestamp != right.mTimestamp) { + return left.mTimestamp > right.mTimestamp; + } + if (left.mKey != right.mKey) { + return left.mKey < right.mKey; + } + if (left.mEntryLevel != right.mEntryLevel) { + return left.mEntryLevel > right.mEntryLevel; + } + for (int i = 0; i < left.mEntryLevel; ++i) { + if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) { + return left.mPrevWordIds[i] < right.mPrevWordIds[i]; + } + } + // left and right represent the same entry. 
+ return false; +} + +LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, + const int timestamp, const int key, const int entryLevel, const int *const prevWordIds) + : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) { + memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0])); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 31ee2fe24..43b2aab66 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -18,6 +18,7 @@ #define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H #include +#include #include "defines.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" @@ -77,13 +78,43 @@ class LanguageModelDictContent { bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { + for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { + outEntryCounts[i] = 0; + } return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */, headerPolicy, outEntryCounts); } + // entryCounts should be created by updateAllProbabilityEntries. 
+ bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts, + const HeaderPolicy *const headerPolicy); + private: DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); + class EntryInfoToTurncate { + public: + class Comparator { + public: + bool operator()(const EntryInfoToTurncate &left, + const EntryInfoToTurncate &right) const; + private: + DISALLOW_ASSIGNMENT_OPERATOR(Comparator); + }; + + EntryInfoToTurncate(const int probability, const int timestamp, const int key, + const int entryLevel, const int *const prevWordIds); + + int mProbability; + int mTimestamp; + int mKey; + int mEntryLevel; + int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); + }; + TrieMap mTrieMap; const bool mHasHistoricalInfo; @@ -94,6 +125,11 @@ class LanguageModelDictContent { int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts); + bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, + const int maxEntryCount, const int targetLevel); + bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, + const int bitmapEntryIndex, std::vector *const prevWordIds, + std::vector *const outEntryInfo) const; }; } // namespace latinime #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 35bc44b8f..d53575aa7 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -91,6 +91,21 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int 
rootPtNodeArrayPos, AKLOGE("Failed to update probabilities in language model dict content."); return false; } + if (headerPolicy->isDecayingDict()) { + int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount(); + maxEntryCountTable[1] = headerPolicy->getMaxBigramCount(); + for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) { + // TODO: Have max n-gram count. + maxEntryCountTable[i] = headerPolicy->getMaxBigramCount(); + } + if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable, + maxEntryCountTable, headerPolicy)) { + AKLOGE("Failed to truncate entries in language model dict content."); + return false; + } + } + DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader); readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); DynamicPtGcEventListeners @@ -193,6 +208,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return true; } +// TODO: Remove. bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( const Ver4PatriciaTrieNodeReader *const ptNodeReader, Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) { @@ -233,6 +249,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( return true; } +// TODO: Remove. bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) { const TerminalPositionLookupTable *const terminalPosLookupTable = mBuffers->getTerminalPositionLookupTable();