diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v3/dynamic_patricia_trie_writing_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v3/dynamic_patricia_trie_writing_utils.h index 2c3b116b3..43e8ea649 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v3/dynamic_patricia_trie_writing_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v3/dynamic_patricia_trie_writing_utils.h @@ -39,6 +39,13 @@ class DynamicPatriciaTrieWritingUtils { static bool writePtNodeArraySizeAndAdvancePosition(BufferWithExtendableBuffer *const buffer, const size_t arraySize, int *const arraySizeFieldPos); + static bool writeFlags(BufferWithExtendableBuffer *const buffer, + const DynamicPatriciaTrieReadingUtils::NodeFlags nodeFlags, + const int nodeFlagsFieldPos) { + int writingPos = nodeFlagsFieldPos; + return writeFlagsAndAdvancePosition(buffer, nodeFlags, &writingPos); + } + static bool writeFlagsAndAdvancePosition(BufferWithExtendableBuffer *const buffer, const DynamicPatriciaTrieReadingUtils::NodeFlags nodeFlags, int *const nodeFlagsFieldPos); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 95921c580..d52cc4925 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -160,8 +160,21 @@ bool Ver4PatriciaTrieNodeWriter::addNewBigramEntry( const PtNodeParams *const sourcePtNodeParams, const PtNodeParams *const targetPtNodeParam, const int probability, const int timestamp, bool *const outAddedNewBigram) { - return mBigramPolicy->addNewEntry(sourcePtNodeParams->getTerminalId(), - targetPtNodeParam->getTerminalId(), probability, timestamp, outAddedNewBigram); + if (!mBigramPolicy->addNewEntry(sourcePtNodeParams->getTerminalId(), + targetPtNodeParam->getTerminalId(), probability, timestamp, outAddedNewBigram)) { + AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", + sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId()); + return false; + } + if (!sourcePtNodeParams->hasBigrams()) { + // Update has bigrams flag. + return updatePtNodeFlags(sourcePtNodeParams->getHeadPos(), + sourcePtNodeParams->isBlacklisted(), sourcePtNodeParams->isNotAWord(), + sourcePtNodeParams->isTerminal(), sourcePtNodeParams->hasShortcutTargets(), + true /* hasBigrams */, + sourcePtNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); + } + return true; } bool Ver4PatriciaTrieNodeWriter::removeBigramEntry( @@ -220,8 +233,31 @@ bool Ver4PatriciaTrieNodeWriter::updateAllPositionFields( bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptNodeParams, const int *const targetCodePoints, const int targetCodePointCount, const int shortcutProbability) { - return mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(), - targetCodePoints, targetCodePointCount, shortcutProbability); + if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(), + targetCodePoints, targetCodePointCount, shortcutProbability)) { + AKLOGE("Cannot add new shortuct entry. terminalId: %d", ptNodeParams->getTerminalId()); + return false; + } + if (!ptNodeParams->hasShortcutTargets()) { + // Update has shortcut targets flag. + return updatePtNodeFlags(ptNodeParams->getHeadPos(), + ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(), + ptNodeParams->isTerminal(), true /* hasShortcutTargets */, + ptNodeParams->hasBigrams(), + ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); + } + return true; +} + +bool Ver4PatriciaTrieNodeWriter::updatePtNodeHasBigramsAndShortcutTargetsFlags( + const PtNodeParams *const ptNodeParams) { + const bool hasBigrams = mBuffers->getBigramDictContent()->getBigramListHeadPos( + ptNodeParams->getTerminalId()) != NOT_A_DICT_POS; + const bool hasShortcutTargets = mBuffers->getShortcutDictContent()->getShortcutListHeadPos( + ptNodeParams->getTerminalId()) != NOT_A_DICT_POS; + return updatePtNodeFlags(ptNodeParams->getHeadPos(), ptNodeParams->isBlacklisted(), + ptNodeParams->isNotAWord(), ptNodeParams->isTerminal(), hasShortcutTargets, + hasBigrams, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( @@ -273,19 +309,9 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( ptNodeParams->getChildrenPos(), ptNodeWritingPos)) { return false; } - // Create node flags and write them. - PatriciaTrieReadingUtils::NodeFlags nodeFlags = - PatriciaTrieReadingUtils::createAndGetFlags(ptNodeParams->isBlacklisted(), - ptNodeParams->isNotAWord(), isTerminal, - ptNodeParams->hasShortcutTargets(), ptNodeParams->hasBigrams(), - ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */, - CHILDREN_POSITION_FIELD_SIZE); - int flagsFieldPos = nodePos; - if (!DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, nodeFlags, - &flagsFieldPos)) { - return false; - } - return true; + return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(), + isTerminal, ptNodeParams->hasShortcutTargets(), ptNodeParams->hasBigrams(), + ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( @@ -304,4 +330,19 @@ const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( } } +bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, + const bool isBlacklisted, const bool isNotAWord, const bool isTerminal, + const bool hasShortcutTargets, const bool hasBigrams, const bool hasMultipleChars) { + // Create node flags and write them. + PatriciaTrieReadingUtils::NodeFlags nodeFlags = + PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, isTerminal, + hasShortcutTargets, hasBigrams, hasMultipleChars, + CHILDREN_POSITION_FIELD_SIZE); + if (!DynamicPatriciaTrieWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) { + AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos); + return false; + } + return true; +} + } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h index 4a2a79259..73b2ae309 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h @@ -87,6 +87,8 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { const int *const targetCodePoints, const int targetCodePointCount, const int shortcutProbability); + bool updatePtNodeHasBigramsAndShortcutTargetsFlags(const PtNodeParams *const ptNodeParams); + private: DISALLOW_COPY_AND_ASSIGN(Ver4PatriciaTrieNodeWriter); @@ -100,6 +102,10 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { const ProbabilityEntry *const originalProbabilityEntry, const int newProbability, const int timestamp) const; + bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord, + const bool isTerminal, const bool hasShortcutTargets, const bool hasBigrams, + const bool hasMultipleChars); + static const int CHILDREN_POSITION_FIELD_SIZE; BufferWithExtendableBuffer *const mTrieBuffer; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 2a80784f8..6ab460ff6 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -172,18 +172,18 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return false; } newDictReadingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); - TraversePolicyToUpdateAllTerminalIds traversePolicyToUpdateAllTerminalIds(&newPtNodeWriter, - &terminalIdMap); + TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds + traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds(&newPtNodeWriter, &terminalIdMap); if (!newDictReadingHelper.traverseAllPtNodesInPostorderDepthFirstManner( - &traversePolicyToUpdateAllTerminalIds)) { + &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) { return false; } *outUnigramCount = traversePolicyToUpdateAllPositionFields.getUnigramCount(); return true; } -bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllTerminalIds::onVisitingPtNode( - const PtNodeParams *const ptNodeParams) { +bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds + ::onVisitingPtNode(const PtNodeParams *const ptNodeParams) { if (!ptNodeParams->isTerminal()) { return true; } @@ -194,7 +194,10 @@ bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllTerminalIds::onVisi ptNodeParams->getTerminalId(), mTerminalIdMap->size()); return false; } - return mPtNodeWriter->updateTerminalId(ptNodeParams, it->second); + if (!mPtNodeWriter->updateTerminalId(ptNodeParams, it->second)) { + AKLOGE("Cannot update terminal id. %d -> %d", it->first, it->second); + } + return mPtNodeWriter->updatePtNodeHasBigramsAndShortcutTargetsFlags(ptNodeParams); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h index 9344bde39..505012b4f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h @@ -41,12 +41,13 @@ class Ver4PatriciaTrieWritingHelper { private: DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTrieWritingHelper); - class TraversePolicyToUpdateAllTerminalIds + class TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds : public DynamicPatriciaTrieReadingHelper::TraversingEventListener { public: - TraversePolicyToUpdateAllTerminalIds(Ver4PatriciaTrieNodeWriter *const ptNodeWriter, + TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds( + Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap) - : mPtNodeWriter(ptNodeWriter), mTerminalIdMap(terminalIdMap) {}; + : mPtNodeWriter(ptNodeWriter), mTerminalIdMap(terminalIdMap) {} bool onAscend() { return true; } @@ -57,7 +58,7 @@ class Ver4PatriciaTrieWritingHelper { bool onVisitingPtNode(const PtNodeParams *const ptNodeParams); private: - DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToUpdateAllTerminalIds); + DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds); Ver4PatriciaTrieNodeWriter *const mPtNodeWriter; const TerminalPositionLookupTable::TerminalIdMap *const mTerminalIdMap;