diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index 3ac424fea..a3d8ec158 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -335,8 +335,9 @@ static void latinime_BinaryDictionary_addUnigramWord(JNIEnv *env, jclass clazz, if (!shortcutTargetCodePoints.empty()) { shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); } + // Use 1 for count to indicate the word has inputed. const UnigramProperty unigramProperty(isNotAWord, isBlacklisted, - probability, timestamp, 0 /* level */, 0 /* count */, &shortcuts); + probability, timestamp, 0 /* level */, 1 /* count */, &shortcuts); dictionary->addUnigramWord(codePoints, codePointCount, &unigramProperty); } @@ -436,8 +437,9 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j env->GetIntField(languageModelParam, shortcutProbabilityFieldId); shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); } + // Use 1 for count to indicate the word has inputed. const UnigramProperty unigramProperty(isNotAWord, isBlacklisted, - unigramProbability, timestamp, 0 /* level */, 0 /* count */, &shortcuts); + unigramProbability, timestamp, 0 /* level */, 1 /* count */, &shortcuts); dictionary->addUnigramWord(word1CodePoints, word1Length, &unigramProperty); if (word0) { jint bigramProbability = env->GetIntField(languageModelParam, bigramProbabilityFieldId); diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.cpp index 4975512ff..1645039d3 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.cpp @@ -257,10 +257,12 @@ const BigramEntry Ver4BigramListPolicy::createUpdatedBigramEntryFrom( const int timestamp) const { // TODO: Consolidate historical info and probability. if (mHeaderPolicy->hasHistoricalInfoOfWords()) { + // Use 1 for count to indicate the bigram has inputed. + const HistoricalInfo historicalInfoForUpdate(timestamp, 0 /* level */, 1 /* count */); const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( - originalBigramEntry->getHistoricalInfo(), newProbability, timestamp, - mHeaderPolicy); + originalBigramEntry->getHistoricalInfo(), newProbability, + &historicalInfoForUpdate, mHeaderPolicy); return originalBigramEntry->updateHistoricalInfoAndGetEntry(&updatedHistoricalInfo); } else { return originalBigramEntry->updateProbabilityAndGetEntry(newProbability); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 50a3e56e3..cc3a24a22 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -387,11 +387,12 @@ const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( const UnigramProperty *const unigramProperty) const { // TODO: Consolidate historical info and probability. if (mHeaderPolicy->hasHistoricalInfoOfWords()) { + const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(), + unigramProperty->getLevel(), unigramProperty->getCount()); const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( originalProbabilityEntry->getHistoricalInfo(), - unigramProperty->getProbability(), unigramProperty->getTimestamp(), - mHeaderPolicy); + unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy); return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo( &updatedHistoricalInfo); } else { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 2584fe5b7..9999e0692 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -425,6 +425,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code } int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) { + // TODO: Return code point count like other methods. + // Null termination. + outCodePoints[0] = 0; if (token == 0) { mTerminalPtNodePositionsForIteratingWords.clear(); DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy( @@ -441,8 +444,13 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; int unigramProbability = NOT_A_PROBABILITY; - getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH, - outCodePoints, &unigramProbability); + const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( + terminalPtNodePos, MAX_WORD_LENGTH, outCodePoints, &unigramProbability); + if (codePointCount < MAX_WORD_LENGTH) { + // Null termination. outCodePoints have to be null terminated or contain MAX_WORD_LENGTH + // code points. + outCodePoints[codePointCount] = 0; + } const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp index c7d3df984..fed0ae77e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp @@ -30,7 +30,7 @@ const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8; const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60; const int ForgettingCurveUtils::MAX_LEVEL = 3; -const int ForgettingCurveUtils::MIN_VALID_LEVEL = 1; +const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1; const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15; const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14; @@ -41,25 +41,34 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT // TODO: Revise the logic to decide the initial probability depending on the given probability. /* static */ const HistoricalInfo ForgettingCurveUtils::createUpdatedHistoricalInfo( - const HistoricalInfo *const originalHistoricalInfo, - const int newProbability, const int timestamp, const HeaderPolicy *const headerPolicy) { + const HistoricalInfo *const originalHistoricalInfo, const int newProbability, + const HistoricalInfo *const newHistoricalInfo, const HeaderPolicy *const headerPolicy) { + const int timestamp = newHistoricalInfo->getTimeStamp(); if (newProbability != NOT_A_PROBABILITY && originalHistoricalInfo->getLevel() == 0) { - return HistoricalInfo(timestamp, MIN_VALID_LEVEL /* level */, 0 /* count */); - } else if (!originalHistoricalInfo->isValid()) { + // Add entry as a valid word. + const int level = clampToVisibleEntryLevelRange(newHistoricalInfo->getLevel()); + const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy); + return HistoricalInfo(timestamp, level, count); + } else if (!originalHistoricalInfo->isValid() + || originalHistoricalInfo->getLevel() < newHistoricalInfo->getLevel() + || (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel() + && originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) { // Initial information. - return HistoricalInfo(timestamp, 0 /* level */, 1 /* count */); + const int level = clampToValidLevelRange(newHistoricalInfo->getLevel()); + const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy); + return HistoricalInfo(timestamp, level, count); } else { const int updatedCount = originalHistoricalInfo->getCount() + 1; if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) { // The count exceeds the max value the level can be incremented. if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) { // The level is already max. - return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), - originalHistoricalInfo->getCount()); + return HistoricalInfo(timestamp, + originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount()); } else { // Level up. - return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel() + 1, - 0 /* count */); + return HistoricalInfo(timestamp, + originalHistoricalInfo->getLevel() + 1, 0 /* count */); } } else { return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), updatedCount); @@ -73,8 +82,8 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT headerPolicy->getForgettingCurveDurationToLevelDown()); return sProbabilityTable.getProbability( headerPolicy->getForgettingCurveProbabilityValuesTableId(), - std::min(std::max(historicalInfo->getLevel(), 0), MAX_LEVEL), - std::min(std::max(elapsedTimeStepCount, 0), MAX_ELAPSED_TIME_STEP_COUNT)); + clampToValidLevelRange(historicalInfo->getLevel()), + clampToValidTimeStepCountRange(elapsedTimeStepCount)); } /* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability, @@ -155,6 +164,23 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT return elapsedTimeInSeconds / timeStepDurationInSeconds; } +/* static */ int ForgettingCurveUtils::clampToVisibleEntryLevelRange(const int level) { + return std::min(std::max(level, MIN_VISIBLE_LEVEL), MAX_LEVEL); +} + +/* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count, + const HeaderPolicy *const headerPolicy) { + return std::min(std::max(count, 0), headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1); +} + +/* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) { + return std::min(std::max(level, 0), MAX_LEVEL); +} + +/* static */ int ForgettingCurveUtils::clampToValidTimeStepCountRange(const int timeStepCount) { + return std::min(std::max(timeStepCount, 0), MAX_ELAPSED_TIME_STEP_COUNT); +} + const int ForgettingCurveUtils::ProbabilityTable::PROBABILITY_TABLE_COUNT = 4; const int ForgettingCurveUtils::ProbabilityTable::WEAK_PROBABILITY_TABLE_ID = 0; const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID = 1; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h index bb8690939..3ff80aeec 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h @@ -30,7 +30,7 @@ class ForgettingCurveUtils { public: static const HistoricalInfo createUpdatedHistoricalInfo( const HistoricalInfo *const originalHistoricalInfo, const int newProbability, - const int timestamp, const HeaderPolicy *const headerPolicy); + const HistoricalInfo *const newHistoricalInfo, const HeaderPolicy *const headerPolicy); static const HistoricalInfo createHistoricalInfoToSave( const HistoricalInfo *const originalHistoricalInfo, @@ -93,7 +93,7 @@ class ForgettingCurveUtils { static const int DECAY_INTERVAL_SECONDS; static const int MAX_LEVEL; - static const int MIN_VALID_LEVEL; + static const int MIN_VISIBLE_LEVEL; static const int MAX_ELAPSED_TIME_STEP_COUNT; static const int DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD; @@ -103,8 +103,11 @@ class ForgettingCurveUtils { static const ProbabilityTable sProbabilityTable; static int backoff(const int unigramProbability); - static int getElapsedTimeStepCount(const int timestamp, const int durationToLevelDown); + static int clampToVisibleEntryLevelRange(const int level); + static int clampToValidLevelRange(const int level); + static int clampToValidCountRange(const int count, const HeaderPolicy *const headerPolicy); + static int clampToValidTimeStepCountRange(const int timeStepCount); }; } // namespace latinime #endif /* LATINIME_FORGETTING_CURVE_UTILS_H */ diff --git a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java index ae2205b36..aed24c56e 100644 --- a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java +++ b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java @@ -93,15 +93,17 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase { private File createEmptyDictionaryAndGetFile(final String dictId, final int formatVersion) throws IOException { - if (formatVersion == FormatSpec.VERSION4) { - return createEmptyVer4DictionaryAndGetFile(dictId); + if (formatVersion == FormatSpec.VERSION4 + || formatVersion == FormatSpec.VERSION4_ONLY_FOR_TESTING) { + return createEmptyVer4DictionaryAndGetFile(dictId, formatVersion); } else { throw new IOException("Dictionary format version " + formatVersion + " is not supported."); } } - private File createEmptyVer4DictionaryAndGetFile(final String dictId) throws IOException { + private File createEmptyVer4DictionaryAndGetFile(final String dictId, final int formatVersion) + throws IOException { final File file = File.createTempFile(dictId, TEST_DICT_FILE_EXTENSION, getContext().getCacheDir()); FileUtils.deleteRecursively(file); @@ -113,7 +115,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase { DictionaryHeader.ATTRIBUTE_VALUE_TRUE); attributeMap.put(DictionaryHeader.HAS_HISTORICAL_INFO_KEY, DictionaryHeader.ATTRIBUTE_VALUE_TRUE); - if (BinaryDictionaryUtils.createEmptyDictFile(file.getAbsolutePath(), FormatSpec.VERSION4, + if (BinaryDictionaryUtils.createEmptyDictFile(file.getAbsolutePath(), formatVersion, LocaleUtils.constructLocaleFromString(TEST_LOCALE), attributeMap)) { return file; } else { @@ -562,4 +564,43 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase { } } } + + public void testDictMigration() { + testDictMigration(FormatSpec.VERSION4_ONLY_FOR_TESTING, FormatSpec.VERSION4); + } + + private void testDictMigration(final int fromFormatVersion, final int toFormatVersion) { + setCurrentTimeForTestMode(mCurrentTime); + File dictFile = null; + try { + dictFile = createEmptyDictionaryAndGetFile("TestBinaryDictionary", fromFormatVersion); + } catch (IOException e) { + fail("IOException while writing an initial dictionary : " + e); + } + final BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(), + 0 /* offset */, dictFile.length(), true /* useFullEditDistance */, + Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */); + // TODO: Add tests for bigrams when the implementation gets ready. + addUnigramWord(binaryDictionary, "aaa", DUMMY_PROBABILITY); + assertTrue(binaryDictionary.isValidWord("aaa")); + addUnigramWord(binaryDictionary, "bbb", Dictionary.NOT_A_PROBABILITY); + assertFalse(binaryDictionary.isValidWord("bbb")); + addUnigramWord(binaryDictionary, "ccc", DUMMY_PROBABILITY); + addUnigramWord(binaryDictionary, "ccc", DUMMY_PROBABILITY); + addUnigramWord(binaryDictionary, "ccc", DUMMY_PROBABILITY); + addUnigramWord(binaryDictionary, "ccc", DUMMY_PROBABILITY); + addUnigramWord(binaryDictionary, "ccc", DUMMY_PROBABILITY); + + assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion()); + assertTrue(binaryDictionary.migrateTo(toFormatVersion)); + assertTrue(binaryDictionary.isValidDictionary()); + assertEquals(toFormatVersion, binaryDictionary.getFormatVersion()); + assertTrue(binaryDictionary.isValidWord("aaa")); + assertFalse(binaryDictionary.isValidWord("bbb")); + assertTrue(binaryDictionary.getFrequency("aaa") < binaryDictionary.getFrequency("ccc")); + addUnigramWord(binaryDictionary, "bbb", Dictionary.NOT_A_PROBABILITY); + assertTrue(binaryDictionary.isValidWord("bbb")); + binaryDictionary.close(); + dictFile.delete(); + } }