diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index e72226cec..461d1d859 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -629,8 +629,7 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j } } while (token != 0); - // Add bigrams. - // TODO: Support ngrams. + // Add ngrams. do { token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount); const WordProperty wordProperty = dictionary->getWordProperty( diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 66cb0513b..08e39ce43 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -580,10 +580,12 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH, bigramWord1CodePoints); const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); - const int probability = bigramEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - bigramEntry.getHistoricalInfo(), mHeaderPolicy) : - bigramEntry.getProbability(); + const int rawBigramProbability = bigramEntry.hasHistoricalInfo() + ? ForgettingCurveUtils::decodeProbability( + bigramEntry.getHistoricalInfo(), mHeaderPolicy) + : bigramEntry.getProbability(); + const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(), + ptNodeParams.representsBeginningOfSentence(), rawBigramProbability); ngrams.emplace_back( NgramContext(wordCodePoints.data(), wordCodePoints.size(), ptNodeParams.representsBeginningOfSentence()), diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index a88996524..b96290437 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -140,6 +140,44 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo); } +std::vector + LanguageModelDictContent::exportAllNgramEntriesRelatedToWord( + const HeaderPolicy *const headerPolicy, const int wordId) const { + const TrieMap::Result result = mTrieMap.getRoot(wordId); + if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) { + // The word doesn't have any related ngram entries. + return std::vector(); + } + std::vector prevWordIds = { wordId }; + std::vector entries; + exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex, + &prevWordIds, &entries); + return entries; +} + +void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner( + const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex, + std::vector *const prevWordIds, + std::vector *const outBummpedFullEntryInfo) const { + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + const int wordId = entry.key(); + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + if (probabilityEntry.isValid()) { + const WordAttributes wordAttributes = getWordAttributes( + WordIdArrayView(*prevWordIds), wordId, headerPolicy); + outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId, + wordAttributes, probabilityEntry); + } + if (entry.hasNextLevelMap()) { + prevWordIds->push_back(wordId); + exportAllNgramEntriesRelatedToWordInner(headerPolicy, + entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo); + prevWordIds->pop_back(); + } + } +} + bool LanguageModelDictContent::truncateEntries(const EntryCounts ¤tEntryCounts, const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters) { @@ -231,24 +269,25 @@ bool LanguageModelDictContent::runGCInner( } int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) { - if (prevWordIds.empty()) { - return mTrieMap.getRootBitmapEntryIndex(); - } - const int lastBitmapEntryIndex = - getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1)); - if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) { - return TrieMap::INVALID_INDEX; - } - const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID); - const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex); - if (!result.mIsValid) { - if (!mTrieMap.put(oldestPrevWordId, - ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) { - return TrieMap::INVALID_INDEX; + int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); + for (const int wordId : prevWordIds) { + const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex); + if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) { + lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex; + continue; } + if (!result.mIsValid) { + if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo), + lastBitmapEntryIndex)) { + AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId, + lastBitmapEntryIndex); + return TrieMap::INVALID_INDEX; + } + } + lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId, + lastBitmapEntryIndex); } - return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID), - lastBitmapEntryIndex); + return lastBitmapEntryIndex; } int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 41a429a54..1cccf92d2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -110,6 +110,27 @@ class LanguageModelDictContent { const bool mHasHistoricalInfo; }; + class DumppedFullEntryInfo { + public: + DumppedFullEntryInfo(std::vector &prevWordIds, const int targetWordId, + const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry) + : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId), + mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {} + + const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); } + int getTargetWordId() const { return mTargetWordId; } + const WordAttributes &getWordAttributes() const { return mWordAttributes; } + const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo); + + const std::vector mPrevWordIds; + const int mTargetWordId; + const WordAttributes mWordAttributes; + const ProbabilityEntry mProbabilityEntry; + }; + LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer, const bool hasHistoricalInfo) : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {} @@ -151,6 +172,9 @@ class LanguageModelDictContent { EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const; + std::vector exportAllNgramEntriesRelatedToWord( + const HeaderPolicy *const headerPolicy, const int wordId) const; + bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters) { return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), @@ -212,6 +236,9 @@ class LanguageModelDictContent { const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry, const bool isValid, const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const; + void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy, + const int bitmapEntryIndex, std::vector *const prevWordIds, + std::vector *const outBummpedFullEntryInfo) const; }; } // namespace latinime #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 1669375e2..193326d82 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -491,30 +491,37 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( const int ptNodePos = mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - const ProbabilityEntry probabilityEntry = - mBuffers->getLanguageModelDictContent()->getProbabilityEntry( - ptNodeParams.getTerminalId()); - const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); - // Fetch bigram information. - // TODO: Support n-gram. + const LanguageModelDictContent *const languageModelDictContent = + mBuffers->getLanguageModelDictContent(); + // Fetch ngram information. std::vector ngrams; - const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId); - int bigramWord1CodePoints[MAX_WORD_LENGTH]; - for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries( - prevWordIds)) { - const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(), - MAX_WORD_LENGTH, bigramWord1CodePoints); + int ngramTargetCodePoints[MAX_WORD_LENGTH]; + int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; + int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord( + mHeaderPolicy, wordId)) { + const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(), + MAX_WORD_LENGTH, ngramTargetCodePoints); + const WordIdArrayView prevWordIds = entry.getPrevWordIds(); + for (size_t i = 0; i < prevWordIds.size(); ++i) { + ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i], + MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]); + ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry( + prevWordIds[i]).representsBeginningOfSentence(); + if (ngramPrevWordIsBeginningOfSentense[i]) { + ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker( + ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]); + } + } + const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount, + ngramPrevWordIsBeginningOfSentense, prevWordIds.size()); const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry(); const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo(); - const int probability = ngramProbabilityEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) : - ngramProbabilityEntry.getProbability(); - ngrams.emplace_back( - NgramContext( - wordCodePoints.data(), wordCodePoints.size(), - probabilityEntry.representsBeginningOfSentence()), - CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(), - probability, *historicalInfo); + // TODO: Output flags in WordAttributes. + ngrams.emplace_back(ngramContext, + CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(), + entry.getWordAttributes().getProbability(), *historicalInfo); } // Fetch shortcut information. std::vector shortcuts; @@ -534,6 +541,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( shortcutProbability); } } + const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry( + ptNodeParams.getTerminalId()); + const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(), probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(), probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(), diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h index 5e9cdd9b2..7871c26ef 100644 --- a/native/jni/src/utils/char_utils.h +++ b/native/jni/src/utils/char_utils.h @@ -101,6 +101,17 @@ class CharUtils { return codePointCount + 1; } + // Returns updated code point count. + static AK_FORCE_INLINE int removeBeginningOfSentenceMarker(int *const codePoints, + const int codePointCount) { + if (codePointCount <= 0 || codePoints[0] != CODE_POINT_BEGINNING_OF_SENTENCE) { + return codePointCount; + } + const int newCodePointCount = codePointCount - 1; + memmove(codePoints, codePoints + 1, sizeof(int) * newCodePointCount); + return newCodePointCount; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils); diff --git a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java index 83e523c3c..9eb03e69a 100644 --- a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java +++ b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java @@ -653,6 +653,13 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase { assertFalse(binaryDictionary.isValidWord("bbb")); assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb")); + if (supportsNgram(toFormatVersion)) { + onInputWordWithPrevWords(binaryDictionary, "xyz", true, "abc", "aaa"); + assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz")); + onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa"); + assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def")); + } + assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion()); assertTrue(binaryDictionary.migrateTo(toFormatVersion)); assertTrue(binaryDictionary.isValidDictionary()); @@ -666,6 +673,14 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase { assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb")); onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa"); assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb")); + + if (supportsNgram(toFormatVersion)) { + assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz")); + assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def")); + onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa"); + assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "def")); + } + binaryDictionary.close(); }