From 3601c214f80cf62eecacd84b2fb27fe9c6b14a19 Mon Sep 17 00:00:00 2001
From: Keisuke Kuroyanagi
Date: Wed, 15 Oct 2014 20:43:27 +0900
Subject: [PATCH] Update useless n-gram entry detection logic during GC.

Bug: 14425059
Change-Id: Ib939deae5b60167751dee07965bb1ef1a43c4625
---
 .../content/language_model_dict_content.cpp   | 45 ++++++++------
 .../v4/content/language_model_dict_content.h  |  8 +--
 .../latin/BinaryDictionaryDecayingTests.java  | 61 ++++++++++++++++++-
 3 files changed, 92 insertions(+), 22 deletions(-)

diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index a7296a302..c4297f5d6 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -270,16 +270,26 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
 }
 
 bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
-        const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
+        const int prevWordCount, const HeaderPolicy *const headerPolicy,
+        int *const outEntryCounts) {
     for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
-        if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
-            AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
-                    level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
+        if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
+            AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
+                    prevWordCount, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
             return false;
         }
         const ProbabilityEntry probabilityEntry = ProbabilityEntry::decode(entry.value(),
                 mHasHistoricalInfo);
-        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
+        if (prevWordCount > 0 && probabilityEntry.isValid()
+                && !mTrieMap.getRoot(entry.key()).mIsValid) {
+            // The entry is related to a word that has been removed. Remove the entry.
+            if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
+                return false;
+            }
+            continue;
+        }
+        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()
+                && probabilityEntry.isValid()) {
             const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
                     probabilityEntry.getHistoricalInfo(), headerPolicy);
             if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
@@ -298,13 +308,13 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
             }
         }
         if (!probabilityEntry.representsBeginningOfSentence()) {
-            outEntryCounts[level] += 1;
+            outEntryCounts[prevWordCount] += 1;
         }
         if (!entry.hasNextLevelMap()) {
             continue;
         }
-        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
-                headerPolicy, outEntryCounts)) {
+        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
+                prevWordCount + 1, headerPolicy, outEntryCounts)) {
             return false;
         }
     }
@@ -332,7 +342,7 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
     for (int i = 0; i < entryCountToRemove; ++i) {
         const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
         if (!removeNgramProbabilityEntry(
-                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
+                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) {
             return false;
         }
     }
@@ -342,9 +352,9 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
 bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
         const int targetLevel, const int bitmapEntryIndex, std::vector *const prevWordIds,
         std::vector *const outEntryInfo) const {
-    const int currentLevel = prevWordIds->size();
+    const int prevWordCount = prevWordIds->size();
     for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
-        if (currentLevel < targetLevel) {
+        if (prevWordCount < targetLevel) {
             if (!entry.hasNextLevelMap()) {
                 continue;
             }
@@ -379,10 +389,10 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
     if (left.mKey != right.mKey) {
         return left.mKey < right.mKey;
     }
-    if (left.mEntryLevel != right.mEntryLevel) {
-        return left.mEntryLevel > right.mEntryLevel;
+    if (left.mPrevWordCount != right.mPrevWordCount) {
+        return left.mPrevWordCount > right.mPrevWordCount;
     }
-    for (int i = 0; i < left.mEntryLevel; ++i) {
+    for (int i = 0; i < left.mPrevWordCount; ++i) {
         if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
             return left.mPrevWordIds[i] < right.mPrevWordIds[i];
         }
@@ -392,9 +402,10 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
 }
 
 LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
-        const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
-        : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
-    memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0]));
+        const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds)
+        : mProbability(probability), mTimestamp(timestamp), mKey(key),
+                mPrevWordCount(prevWordCount) {
+    memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
 }
 
 } // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 834cf933d..51ef090e1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -160,7 +160,7 @@ class LanguageModelDictContent {
             outEntryCounts[i] = 0;
         }
         return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
-                0 /* level */, headerPolicy, outEntryCounts);
+                0 /* prevWordCount */, headerPolicy, outEntryCounts);
     }
 
     // entryCounts should be created by updateAllProbabilityEntries.
@@ -185,12 +185,12 @@ class LanguageModelDictContent {
         };
 
         EntryInfoToTurncate(const int probability, const int timestamp, const int key,
-                const int entryLevel, const int *const prevWordIds);
+                const int prevWordCount, const int *const prevWordIds);
 
         int mProbability;
         int mTimestamp;
         int mKey;
-        int mEntryLevel;
+        int mPrevWordCount;
         int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
 
      private:
@@ -208,7 +208,7 @@ class LanguageModelDictContent {
             int *const outNgramCount);
     int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
     int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
-    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int level,
+    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
             const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
     bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
             const int maxEntryCount, const int targetLevel, int *const outEntryCount);
diff --git a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
index 0e58b7211..fa70f9988 100644
--- a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
+++ b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
@@ -75,6 +75,10 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
         return formatVersion > FormatSpec.VERSION401;
     }
 
+    private static boolean supportsNgram(final int formatVersion) {
+        return formatVersion >= FormatSpec.VERSION4_DEV;
+    }
+
     private void onInputWord(final BinaryDictionary binaryDictionary, final String word,
             final boolean isValidWord) {
         binaryDictionary.updateEntriesForWordWithNgramContext(NgramContext.EMPTY_PREV_WORDS_INFO,
@@ -88,6 +92,14 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
                 mCurrentTime /* timestamp */);
     }
 
+    private void onInputWordWithPrevWords(final BinaryDictionary binaryDictionary,
+            final String word, final boolean isValidWord, final String prevWord,
+            final String prevPrevWord) {
+        binaryDictionary.updateEntriesForWordWithNgramContext(
+                new NgramContext(new WordInfo(prevWord), new WordInfo(prevPrevWord)), word,
+                isValidWord, 1 /* count */, mCurrentTime /* timestamp */);
+    }
+
     private void onInputWordWithBeginningOfSentenceContext(
             final BinaryDictionary binaryDictionary, final String word, final boolean isValidWord) {
         binaryDictionary.updateEntriesForWordWithNgramContext(NgramContext.BEGINNING_OF_SENTENCE,
@@ -99,6 +111,12 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
         return binaryDictionary.isValidNgram(new NgramContext(new WordInfo(word0)), word1);
     }
 
+    private static boolean isValidTrigram(final BinaryDictionary binaryDictionary,
+            final String word0, final String word1, final String word2) {
+        return binaryDictionary.isValidNgram(
+                new NgramContext(new WordInfo(word1), new WordInfo(word0)), word2);
+    }
+
     private void forcePassingShortTime(final BinaryDictionary binaryDictionary) {
         // 30 days.
         final int timeToElapse = (int)TimeUnit.SECONDS.convert(30, TimeUnit.DAYS);
@@ -256,7 +274,23 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
         onInputWordWithPrevWord(binaryDictionary, "y", true /* isValidWord */, "x");
         assertFalse(isValidBigram(binaryDictionary, "x", "y"));
 
-        binaryDictionary.close();
+        if (!supportsNgram(formatVersion)) {
+            return;
+        }
+
+        onInputWordWithPrevWords(binaryDictionary, "c", false /* isValidWord */, "b", "a");
+        assertFalse(isValidTrigram(binaryDictionary, "a", "b", "c"));
+        assertFalse(isValidBigram(binaryDictionary, "b", "c"));
+        onInputWordWithPrevWords(binaryDictionary, "c", false /* isValidWord */, "b", "a");
+        assertTrue(isValidTrigram(binaryDictionary, "a", "b", "c"));
+        assertTrue(isValidBigram(binaryDictionary, "b", "c"));
+
+        onInputWordWithPrevWords(binaryDictionary, "d", true /* isValidWord */, "b", "a");
+        assertTrue(isValidTrigram(binaryDictionary, "a", "b", "d"));
+        assertTrue(isValidBigram(binaryDictionary, "b", "d"));
+
+        onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "b", "a");
+        assertTrue(isValidTrigram(binaryDictionary, "a", "b", "cd"));
     }
 
     public void testDecayingProbability() {
@@ -301,6 +335,31 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
         forcePassingLongTime(binaryDictionary);
         assertFalse(isValidBigram(binaryDictionary, "a", "b"));
 
+        if (!supportsNgram(formatVersion)) {
+            return;
+        }
+
+        onInputWord(binaryDictionary, "ab", true /* isValidWord */);
+        onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab");
+        onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab");
+        assertTrue(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
+        forcePassingShortTime(binaryDictionary);
+        assertFalse(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
+
+        onInputWord(binaryDictionary, "ab", true /* isValidWord */);
+        onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab");
+        onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab");
+        onInputWord(binaryDictionary, "ab", true /* isValidWord */);
+        onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab");
+        onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab");
+        onInputWord(binaryDictionary, "ab", true /* isValidWord */);
+        onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab");
+        onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab");
+        forcePassingShortTime(binaryDictionary);
+        assertTrue(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
+        forcePassingLongTime(binaryDictionary);
+        assertFalse(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
+
         binaryDictionary.close();
     }
 