diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index b60499e9f..10f90523a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -121,8 +121,14 @@ bool Ver4PatriciaTrieNodeWriter::markPtNodeAsWillBecomeNonTerminal( const PatriciaTrieReadingUtils::NodeFlags updatedFlags = DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */, false /* isDeleted */, true /* willBecomeNonTerminal */); - int writingPos = toBeUpdatedPtNodeParams->getHeadPos(); + if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition( + toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */)) { + AKLOGE("Cannot update terminal position lookup table. terminal id: %d", + toBeUpdatedPtNodeParams->getTerminalId()); + return false; + } // Update flags. + int writingPos = toBeUpdatedPtNodeParams->getHeadPos(); return DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags, &writingPos); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 21d009ecb..77fb41dc5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -17,6 +17,7 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h" #include +#include #include "suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" @@ -97,10 +98,16 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) { return false; } + const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + .getValidUnigramCount(); if (headerPolicy->isDecayingDict() - && traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted - .getValidUnigramCount() > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) { - // TODO: Remove more unigrams. + && unigramCount > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) { + if (!turncateUnigrams(&ptNodeReader, &ptNodeWriter, + ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC)) { + AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount, + ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC); + return false; + } } readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); @@ -179,6 +186,42 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return true; } +bool Ver4PatriciaTrieWritingHelper::turncateUnigrams( + const Ver4PatriciaTrieNodeReader *const ptNodeReader, + Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) { + const TerminalPositionLookupTable *const terminalPosLookupTable = + mBuffers->getTerminalPositionLookupTable(); + const int nextTerminalId = terminalPosLookupTable->getNextTerminalId(); + std::priority_queue, DictProbabilityComparator> + priorityQueue; + for (int i = 0; i < nextTerminalId; ++i) { + const int terminalPos = terminalPosLookupTable->getTerminalPtNodePosition(i); + if (terminalPos == NOT_A_DICT_POS) { + continue; + } + const ProbabilityEntry probabilityEntry = + mBuffers->getProbabilityDictContent()->getProbabilityEntry(i); + const int probability = probabilityEntry.hasHistoricalInfo() ? + ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo()) : + probabilityEntry.getProbability(); + priorityQueue.push(DictProbability(terminalPos, probability, + probabilityEntry.getHistoricalInfo()->getTimeStamp())); + } + + // Delete unigrams. + while (static_cast(priorityQueue.size()) > maxUnigramCount) { + const int ptNodePos = priorityQueue.top().getDictPos(); + const PtNodeParams ptNodeParams = + ptNodeReader->fetchNodeInfoInBufferFromPtNodePos(ptNodePos); + if (!ptNodeWriter->markPtNodeAsWillBecomeNonTerminal(&ptNodeParams)) { + AKLOGE("Cannot mark PtNode as willBecomeNonterminal. PtNode pos: %d", ptNodePos); + return false; + } + priorityQueue.pop(); + } + return true; +} + bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds ::onVisitingPtNode(const PtNodeParams *const ptNodeParams) { if (!ptNodeParams->isTerminal()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h index 82877fdcc..26eb678b0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h @@ -25,6 +25,7 @@ namespace latinime { class HeaderPolicy; class Ver4DictBuffers; +class Ver4PatriciaTrieNodeReader; class Ver4PatriciaTrieNodeWriter; class Ver4PatriciaTrieWritingHelper { @@ -64,10 +65,56 @@ class Ver4PatriciaTrieWritingHelper { const TerminalPositionLookupTable::TerminalIdMap *const mTerminalIdMap; }; + // For truncateUnigrams(). + class DictProbability { + public: + DictProbability(const int dictPos, const int probability, const int timestamp) + : mDictPos(dictPos), mProbability(probability), mTimestamp(timestamp) {} + + int getDictPos() const { + return mDictPos; + } + + int getProbability() const { + return mProbability; + } + + int getTimestamp() const { + return mTimestamp; + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(DictProbability); + + int mDictPos; + int mProbability; + int mTimestamp; + }; + + // For truncateUnigrams(). + class DictProbabilityComparator { + public: + bool operator()(const DictProbability &left, const DictProbability &right) { + if (left.getProbability() != right.getProbability()) { + return left.getProbability() > right.getProbability(); + } + if (left.getTimestamp() != right.getTimestamp()) { + return left.getTimestamp() < right.getTimestamp(); + } + return left.getDictPos() > right.getDictPos(); + } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(DictProbabilityComparator); + }; + bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount, int *const outBigramCount); + bool turncateUnigrams(const Ver4PatriciaTrieNodeReader *const ptNodeReader, + Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount); + Ver4DictBuffers *const mBuffers; }; } // namespace latinime diff --git a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java index d77a11f01..825b8773c 100644 --- a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java +++ b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java @@ -303,6 +303,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase { BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(), 0 /* offset */, dictFile.length(), true /* useFullEditDistance */, Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */); + setCurrentTime(binaryDictionary, mCurrentTime); final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random); final ArrayList words = new ArrayList(); @@ -339,7 +340,65 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase { forcePassingLongTime(binaryDictionary); assertEquals(0, Integer.parseInt(binaryDictionary.getPropertyForTests( BinaryDictionary.UNIGRAM_COUNT_QUERY))); + } + public void testOverflowUnigrams() { + testOverflowUnigrams(FormatSpec.VERSION4); + } + + private void testOverflowUnigrams(final int formatVersion) { + final int unigramCount = 20000; + final int eachUnigramTypedCount = 5; + final int strongUnigramTypedCount = 20; + final int weakUnigramTypedCount = 1; + final int codePointSetSize = 50; + final long seed = System.currentTimeMillis(); + final Random random = new Random(seed); + + File dictFile = null; + try { + dictFile = createEmptyDictionaryAndGetFile("TestBinaryDictionary", formatVersion); + } catch (IOException e) { + fail("IOException while writing an initial dictionary : " + e); + } + BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(), + 0 /* offset */, dictFile.length(), true /* useFullEditDistance */, + Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */); + setCurrentTime(binaryDictionary, mCurrentTime); + final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random); + + final String strong = "strong"; + final String weak = "weak"; + for (int j = 0; j < strongUnigramTypedCount; j++) { + addUnigramWord(binaryDictionary, strong, DUMMY_PROBABILITY); + } + for (int j = 0; j < weakUnigramTypedCount; j++) { + addUnigramWord(binaryDictionary, weak, DUMMY_PROBABILITY); + } + assertTrue(binaryDictionary.isValidWord(strong)); + assertTrue(binaryDictionary.isValidWord(weak)); + + for (int i = 0; i < unigramCount; i++) { + final String word = CodePointUtils.generateWord(random, codePointSet); + for (int j = 0; j < eachUnigramTypedCount; j++) { + addUnigramWord(binaryDictionary, word, DUMMY_PROBABILITY); + } + if (binaryDictionary.needsToRunGC(true /* mindsBlockByGC */)) { + final int unigramCountBeforeGC = + Integer.parseInt(binaryDictionary.getPropertyForTests( + BinaryDictionary.UNIGRAM_COUNT_QUERY)); + assertTrue(binaryDictionary.isValidWord(strong)); + assertTrue(binaryDictionary.isValidWord(weak)); + binaryDictionary.flushWithGC(); + final int unigramCountAfterGC = + Integer.parseInt(binaryDictionary.getPropertyForTests( + BinaryDictionary.UNIGRAM_COUNT_QUERY)); + assertTrue(unigramCountBeforeGC > unigramCountAfterGC); + assertFalse(binaryDictionary.isValidWord(weak)); + assertTrue(binaryDictionary.isValidWord(strong)); + break; + } + } } public void testAddManyBigramsToDecayingDict() { @@ -363,6 +422,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase { BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(), 0 /* offset */, dictFile.length(), true /* useFullEditDistance */, Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */); + setCurrentTime(binaryDictionary, mCurrentTime); final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random); final ArrayList words = new ArrayList();