diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 696be0aeb..808d2a6cd 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -112,10 +112,10 @@ class DicNode { mIsUsed = true; mIsCachedForNextSuggestion = false; mDicNodeProperties.init( - NOT_A_VALID_WORD_POS /* pos */, rootGroupPos, NOT_A_DICT_POS /* attributesPos */, - NOT_A_CODE_POINT /* nodeCodePoint */, NOT_A_PROBABILITY /* probability */, - false /* isTerminal */, true /* hasChildren */, - false /* isBlacklistedOrNotAWord */, 0 /* depth */, 0 /* terminalDepth */); + NOT_A_VALID_WORD_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, + NOT_A_PROBABILITY /* probability */, false /* isTerminal */, + true /* hasChildren */, false /* isBlacklistedOrNotAWord */, 0 /* depth */, + 0 /* terminalDepth */); mDicNodeState.init(prevWordNodePos); PROF_NODE_RESET(mProfiler); } @@ -125,10 +125,10 @@ class DicNode { mIsUsed = true; mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; mDicNodeProperties.init( - NOT_A_VALID_WORD_POS /* pos */, rootGroupPos, NOT_A_DICT_POS /* attributesPos */, - NOT_A_CODE_POINT /* nodeCodePoint */, NOT_A_PROBABILITY /* probability */, - false /* isTerminal */, true /* hasChildren */, - false /* isBlacklistedOrNotAWord */, 0 /* depth */, 0 /* terminalDepth */); + NOT_A_VALID_WORD_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, + NOT_A_PROBABILITY /* probability */, false /* isTerminal */, + true /* hasChildren */, false /* isBlacklistedOrNotAWord */, 0 /* depth */, + 0 /* terminalDepth */); // TODO: Move to dicNodeState? mDicNodeState.mDicNodeStateOutput.init(); // reset for next word mDicNodeState.mDicNodeStateInput.init( @@ -157,18 +157,16 @@ class DicNode { PROF_NODE_COPY(&parentNode->mProfiler, mProfiler); } - void initAsChild(DicNode *dicNode, const int pos, const int childrenPos, - const int attributesPos, const int probability, const bool isTerminal, - const bool hasChildren, const bool isBlacklistedOrNotAWord, + void initAsChild(DicNode *dicNode, const int pos, const int childrenPos, const int probability, + const bool isTerminal, const bool hasChildren, const bool isBlacklistedOrNotAWord, const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { mIsUsed = true; uint16_t newDepth = static_cast(dicNode->getNodeCodePointCount() + 1); mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; const uint16_t newLeavingDepth = static_cast( dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); - mDicNodeProperties.init(pos, childrenPos, attributesPos, mergedNodeCodePoints[0], - probability, isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth, - newLeavingDepth); + mDicNodeProperties.init(pos, childrenPos, mergedNodeCodePoints[0], probability, + isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth, newLeavingDepth); mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, mergedNodeCodePoints); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); @@ -467,10 +465,6 @@ class DicNode { return mDicNodeProperties.isBlacklistedOrNotAWord(); } - int getAttributesPos() const { - return mDicNodeProperties.getAttributesPos(); - } - inline uint16_t getNodeCodePointCount() const { return mDicNodeProperties.getDepth(); } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/dic_node_properties.h index d98000d83..9e0f62ceb 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_properties.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_properties.h @@ -31,20 +31,17 @@ namespace latinime { class DicNodeProperties { public: AK_FORCE_INLINE DicNodeProperties() - : mPos(0), mChildrenPos(0), mAttributesPos(0), mProbability(0), - mNodeCodePoint(0), mIsTerminal(false), mHasChildren(false), - mIsBlacklistedOrNotAWord(false), mDepth(0), mLeavingDepth(0) {} + : mPos(0), mChildrenPos(0), mProbability(0), mNodeCodePoint(0), mIsTerminal(false), + mHasChildren(false), mIsBlacklistedOrNotAWord(false), mDepth(0), mLeavingDepth(0) {} virtual ~DicNodeProperties() {} // Should be called only once per DicNode is initialized. - void init(const int pos, const int childrenPos, const int attributesPos, - const int nodeCodePoint, const int probability, const bool isTerminal, - const bool hasChildren, const bool isBlacklistedOrNotAWord, + void init(const int pos, const int childrenPos, const int nodeCodePoint, const int probability, + const bool isTerminal, const bool hasChildren, const bool isBlacklistedOrNotAWord, const uint16_t depth, const uint16_t leavingDepth) { mPos = pos; mChildrenPos = childrenPos; - mAttributesPos = attributesPos; mNodeCodePoint = nodeCodePoint; mProbability = probability; mIsTerminal = isTerminal; @@ -58,7 +55,6 @@ class DicNodeProperties { void init(const DicNodeProperties *const nodeProp) { mPos = nodeProp->mPos; mChildrenPos = nodeProp->mChildrenPos; - mAttributesPos = nodeProp->mAttributesPos; mNodeCodePoint = nodeProp->mNodeCodePoint; mProbability = nodeProp->mProbability; mIsTerminal = nodeProp->mIsTerminal; @@ -72,7 +68,6 @@ class DicNodeProperties { void init(const DicNodeProperties *const nodeProp, const int codePoint) { mPos = nodeProp->mPos; mChildrenPos = nodeProp->mChildrenPos; - mAttributesPos = nodeProp->mAttributesPos; mNodeCodePoint = codePoint; // Overwrite the node char of a passing child mProbability = nodeProp->mProbability; mIsTerminal = nodeProp->mIsTerminal; @@ -90,10 +85,6 @@ class DicNodeProperties { return mChildrenPos; } - int getAttributesPos() const { - return mAttributesPos; - } - int getProbability() const { return mProbability; } @@ -129,7 +120,6 @@ class DicNodeProperties { // for this class int mPos; int mChildrenPos; - int mAttributesPos; int mProbability; int mNodeCodePoint; bool mIsTerminal; diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 67fbc1a38..8b6f45599 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -77,7 +77,6 @@ namespace latinime { const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags)); const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); - const bool hasShortcuts = (0 != (BinaryFormat::FLAG_HAS_SHORTCUT_TARGETS & flags)); const bool isBlacklistedOrNotAWord = BinaryFormat::hasBlacklistedOrNotAWordFlag(flags); int codePoint = BinaryFormat::getCodePointAndForwardPointer( @@ -104,17 +103,14 @@ namespace latinime { pos = BinaryFormat::skipProbability(flags, pos); int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition( binaryDictionaryInfo->getDictRoot(), flags, pos) : NOT_A_DICT_POS; - const int attributesPos = - hasShortcuts ? BinaryFormat::skipChildrenPosition(flags, pos) : NOT_A_DICT_POS; const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes( binaryDictionaryInfo->getDictRoot(), flags, pos); if (childrenFilter->isFilteredOut(mergedNodeCodePoints[0])) { return siblingPos; } - childDicNodes->pushLeavingChild(dicNode, nextPos, childrenPos, attributesPos, - probability, isTerminal, hasChildren, isBlacklistedOrNotAWord, - mergedNodeCodePointCount, mergedNodeCodePoints); + childDicNodes->pushLeavingChild(dicNode, nextPos, childrenPos, probability, isTerminal, + hasChildren, isBlacklistedOrNotAWord, mergedNodeCodePointCount, mergedNodeCodePoints); return siblingPos; } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_vector.h b/native/jni/src/suggest/core/dicnode/dic_node_vector.h index 5ac4eeaf4..2ba4e5e95 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_vector.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_vector.h @@ -63,13 +63,13 @@ class DicNodeVector { } void pushLeavingChild(DicNode *dicNode, const int pos, const int childrenPos, - const int attributesPos, const int probability, const bool isTerminal, - const bool hasChildren, const bool isBlacklistedOrNotAWord, - const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { + const int probability, const bool isTerminal, const bool hasChildren, + const bool isBlacklistedOrNotAWord, const uint16_t mergedNodeCodePointCount, + const int *const mergedNodeCodePoints) { ASSERT(!mLock); mDicNodes.push_back(mEmptyNode); - mDicNodes.back().initAsChild(dicNode, pos, childrenPos, attributesPos, probability, - isTerminal, hasChildren, isBlacklistedOrNotAWord, mergedNodeCodePointCount, + mDicNodes.back().initAsChild(dicNode, pos, childrenPos, probability, isTerminal, + hasChildren, isBlacklistedOrNotAWord, mergedNodeCodePointCount, mergedNodeCodePoints); } diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp index 748430233..618a9d2d5 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp @@ -109,13 +109,13 @@ int BigramDictionary::getPredictions(const int *prevWord, int prevWordLength, in int pos = getBigramListPositionForWord(prevWord, prevWordLength, false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams - if (0 == pos) { + if (NOT_A_DICT_POS == pos) { // If no bigrams for this exact word, search again in lower case. pos = getBigramListPositionForWord(prevWord, prevWordLength, true /* forceLowerCaseSearch */); } // If still no bigrams, we really don't have them! - if (0 == pos) return 0; + if (NOT_A_DICT_POS == pos) return 0; int bigramCount = 0; int unigramProbability = 0; @@ -154,8 +154,8 @@ int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const in int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch); if (NOT_A_VALID_WORD_POS == pos) return 0; - return BinaryFormat::getBigramListPositionForWordPosition( - mBinaryDictionaryInfo->getDictRoot(), pos); + return mBinaryDictionaryInfo->getStructurePolicy()->getBigramsPositionOfNode( + mBinaryDictionaryInfo, pos); } bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const { @@ -178,7 +178,7 @@ bool BigramDictionary::isValidBigram(const int *word0, int length0, const int *w int length1) const { int pos = getBigramListPositionForWord(word0, length0, false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams - if (0 == pos) return false; + if (NOT_A_DICT_POS == pos) return false; int nextWordPos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( mBinaryDictionaryInfo, word1, length1, false /* forceLowerCaseSearch */); if (NOT_A_VALID_WORD_POS == nextWordPos) return false; diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h index f2b48e960..8cbb12998 100644 --- a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h @@ -28,7 +28,7 @@ class BinaryDictionaryBigramsIterator { BinaryDictionaryBigramsIterator( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int pos) : mBinaryDictionaryInfo(binaryDictionaryInfo), mPos(pos), mBigramFlags(0), - mBigramPos(0), mHasNext(true) {} + mBigramPos(NOT_A_DICT_POS), mHasNext(pos != NOT_A_DICT_POS) {} AK_FORCE_INLINE bool hasNext() const { return mHasNext; diff --git a/native/jni/src/suggest/core/dictionary/binary_format.h b/native/jni/src/suggest/core/dictionary/binary_format.h index 6a5afd12e..9e22b50cd 100644 --- a/native/jni/src/suggest/core/dictionary/binary_format.h +++ b/native/jni/src/suggest/core/dictionary/binary_format.h @@ -73,8 +73,11 @@ class BinaryFormat { const int length, const bool forceLowerCaseSearch); static int getCodePointsAndProbabilityAndReturnCodePointCount( const uint8_t *const root, const int nodePos, const int maxCodePointCount, - int *outCodePoints, int *outUnigramProbability); - static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); + int *const outCodePoints, int *const outUnigramProbability); + static int getBigramListPositionForWordPosition(const uint8_t *const root, + const int nodePosition); + static int getShortcutListPositionForWordPosition(const uint8_t *const root, + const int nodePosition); private: DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat); @@ -344,8 +347,8 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, * Return value : the length of the word, of 0 if the word was not found. */ AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount( - const uint8_t *const root, const int nodePos, - const int maxCodePointCount, int *outCodePoints, int *outUnigramProbability) { + const uint8_t *const root, const int nodePos, const int maxCodePointCount, + int *const outCodePoints, int *const outUnigramProbability) { int pos = 0; int wordPos = 0; @@ -473,10 +476,11 @@ AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointC } AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition( - const uint8_t *const root, int position) { - if (NOT_A_VALID_WORD_POS == position) return 0; + const uint8_t *const root, const int nodePosition) { + if (NOT_A_VALID_WORD_POS == nodePosition) return NOT_A_DICT_POS; + int position = nodePosition; const uint8_t flags = getFlagsAndForwardPointer(root, &position); - if (!(flags & FLAG_HAS_BIGRAMS)) return 0; + if (!(flags & FLAG_HAS_BIGRAMS)) return NOT_A_DICT_POS; if (flags & FLAG_HAS_MULTIPLE_CHARS) { position = skipOtherCharacters(root, position); } else { @@ -488,5 +492,21 @@ AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition( return position; } +AK_FORCE_INLINE int BinaryFormat::getShortcutListPositionForWordPosition( + const uint8_t *const root, const int nodePosition) { + if (NOT_A_VALID_WORD_POS == nodePosition) return NOT_A_DICT_POS; + int position = nodePosition; + const uint8_t flags = getFlagsAndForwardPointer(root, &position); + if (!(flags & FLAG_HAS_SHORTCUT_TARGETS)) return NOT_A_DICT_POS; + if (flags & FLAG_HAS_MULTIPLE_CHARS) { + position = skipOtherCharacters(root, position); + } else { + getCodePointAndForwardPointer(root, &position); + } + position = skipProbability(flags, position); + position = skipChildrenPosition(flags, position); + return position; +} + } // namespace latinime #endif // LATINIME_BINARY_FORMAT_H diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index 60169f80e..12f1d08b9 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -22,8 +22,8 @@ #include "defines.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/binary_dictionary_info.h" -#include "suggest/core/dictionary/binary_format.h" #include "suggest/core/dictionary/bloom_filter.h" +#include "suggest/core/dictionary/probability_utils.h" #include "utils/hash_map_compat.h" namespace latinime { @@ -67,11 +67,8 @@ class MultiBigramMap { ~BigramMap() {} void init(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) { - const int bigramsListPos = BinaryFormat::getBigramListPositionForWordPosition( - binaryDictionaryInfo->getDictRoot(), nodePos); - if (0 == bigramsListPos) { - return; - } + const int bigramsListPos = binaryDictionaryInfo->getStructurePolicy()-> + getBigramsPositionOfNode(binaryDictionaryInfo, nodePos); for (BinaryDictionaryBigramsIterator bigramsIt(binaryDictionaryInfo, bigramsListPos); bigramsIt.hasNext(); /* no-op */) { bigramsIt.next(); @@ -110,11 +107,8 @@ class MultiBigramMap { AK_FORCE_INLINE int readBigramProbabilityFromBinaryDictionary( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos, const int nextWordPosition, const int unigramProbability) { - const int bigramsListPos = BinaryFormat::getBigramListPositionForWordPosition( - binaryDictionaryInfo->getDictRoot(), nodePos); - if (0 == bigramsListPos) { - return ProbabilityUtils::backoff(unigramProbability); - } + const int bigramsListPos = binaryDictionaryInfo->getStructurePolicy()-> + getBigramsPositionOfNode(binaryDictionaryInfo, nodePos); for (BinaryDictionaryBigramsIterator bigramsIt(binaryDictionaryInfo, bigramsListPos); bigramsIt.hasNext(); /* no-op */) { bigramsIt.next(); diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_policy.h index 48ba5b8c2..cc14c982c 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_policy.h @@ -62,6 +62,12 @@ class DictionaryStructurePolicy { virtual int getUnigramProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) const = 0; + virtual int getShortcutPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const = 0; + + virtual int getBigramsPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const = 0; + protected: DictionaryStructurePolicy() {} virtual ~DictionaryStructurePolicy() {} diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index c6da6f003..d6383b958 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -210,14 +210,16 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen } if (!terminalDicNode->hasMultipleWords()) { + const BinaryDictionaryInfo *const binaryDictionaryInfo = + traverseSession->getBinaryDictionaryInfo(); const TerminalAttributes terminalAttributes(traverseSession->getBinaryDictionaryInfo(), - terminalDicNode->getAttributesPos()); + binaryDictionaryInfo->getStructurePolicy()->getShortcutPositionOfNode( + binaryDictionaryInfo, terminalDicNode->getPos())); // Shortcut is not supported for multiple words suggestions. // TODO: Check shortcuts during traversal for multiple words suggestions. const bool sameAsTyped = TRAVERSAL->sameAsTyped(traverseSession, terminalDicNode); outputWordIndex = ShortcutUtils::outputShortcuts(&terminalAttributes, outputWordIndex, finalScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); - } DicNode::managedDelete(terminalDicNode); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp index c807fb7c9..24de9dcd9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp @@ -69,4 +69,18 @@ int PatriciaTriePolicy::getUnigramProbability( return BinaryFormat::readProbabilityWithoutMovingPointer(root, pos); } +int PatriciaTriePolicy::getShortcutPositionOfNode( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const { + return BinaryFormat::getShortcutListPositionForWordPosition( + binaryDictionaryInfo->getDictRoot(), nodePos); +} + +int PatriciaTriePolicy::getBigramsPositionOfNode( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const { + return BinaryFormat::getBigramListPositionForWordPosition( + binaryDictionaryInfo->getDictRoot(), nodePos); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h index 0a16e414a..8f36fe00e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h @@ -48,6 +48,12 @@ class PatriciaTriePolicy : public DictionaryStructurePolicy { int getUnigramProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) const; + int getShortcutPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const; + + int getBigramsPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const; + private: DISALLOW_COPY_AND_ASSIGN(PatriciaTriePolicy); static const PatriciaTriePolicy sInstance;