diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 214cdfca6..e44d5ae20 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -103,10 +103,10 @@ class DicNode { PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } - // Init for root with prevWordsPtNodePos which is used for n-gram - void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordsPtNodePos) { + // Init for root with prevWordIds which is used for n-gram + void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordIds) { mIsCachedForNextSuggestion = false; - mDicNodeProperties.init(rootPtNodeArrayPos, prevWordsPtNodePos); + mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds); mDicNodeState.init(); PROF_NODE_RESET(mProfiler); } @@ -114,12 +114,12 @@ class DicNode { // Init for root with previous word void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) { mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; - int newPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - newPrevWordsPtNodePos[0] = dicNode->mDicNodeProperties.getPtNodePos(); - for (size_t i = 1; i < NELEMS(newPrevWordsPtNodePos); ++i) { - newPrevWordsPtNodePos[i] = dicNode->getPrevWordsTerminalPtNodePos()[i - 1]; + int newPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId(); + for (size_t i = 1; i < NELEMS(newPrevWordIds); ++i) { + newPrevWordIds[i] = dicNode->getPrevWordIds()[i - 1]; } - mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordsPtNodePos); + mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordIds); mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, dicNode->mDicNodeProperties.getDepth()); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); @@ -145,7 +145,7 @@ class DicNode { dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); 
mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0], probability, wordId, hasChildren, isBlacklistedOrNotAWord, newDepth, - newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordsTerminalPtNodePos()); + newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds()); mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, mergedNodeCodePoints); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); @@ -204,13 +204,18 @@ class DicNode { } // Used to get n-gram probability in DicNodeUtils. + int getWordId() const { + return mDicNodeProperties.getWordId(); + } + + // TODO: Remove int getPtNodePos() const { return mDicNodeProperties.getPtNodePos(); } - // TODO: Use view class to return PtNodePos array. - const int *getPrevWordsTerminalPtNodePos() const { - return mDicNodeProperties.getPrevWordsTerminalPtNodePos(); + // TODO: Use view class to return word id array. + const int *getPrevWordIds() const { + return mDicNodeProperties.getPrevWordIds(); } // Used in DicNodeUtils diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 69ea67418..87d245276 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -29,8 +29,8 @@ namespace latinime { /* static */ void DicNodeUtils::initAsRoot( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const prevWordsPtNodePos, DicNode *const newRootDicNode) { - newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordsPtNodePos); + const int *const prevWordIds, DicNode *const newRootDicNode) { + newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds); } /*static */ void DicNodeUtils::initAsRootWithPreviousWord( @@ -86,9 +86,9 @@ namespace latinime { const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) { const int unigramProbability = 
dicNode->getProbability(); if (multiBigramMap) { - const int *const prevWordsPtNodePos = dicNode->getPrevWordsTerminalPtNodePos(); + const int *const prevWordIds = dicNode->getPrevWordIds(); return multiBigramMap->getBigramProbability(dictionaryStructurePolicy, - prevWordsPtNodePos, dicNode->getPtNodePos(), unigramProbability); + prevWordIds, dicNode->getWordId(), unigramProbability); } return dictionaryStructurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY); diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h index 00e80c604..56ff6e3d0 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -30,7 +30,7 @@ class DicNodeUtils { public: static void initAsRoot( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const prevWordPtNodePos, DicNode *const newRootDicNode); + const int *const prevWordIds, DicNode *const newRootDicNode); static void initAsRootWithPreviousWord( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode); diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h index fc242a92b..1d905b9fe 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h @@ -39,7 +39,7 @@ class DicNodeProperties { // Should be called only once per DicNode is initialized. 
void init(const int pos, const int childrenPos, const int nodeCodePoint, const int probability, const int wordId, const bool hasChildren, const bool isBlacklistedOrNotAWord, - const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordsNodePos) { + const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordIds) { mPtNodePos = pos; mChildrenPtNodeArrayPos = childrenPos; mDicNodeCodePoint = nodeCodePoint; @@ -49,11 +49,11 @@ class DicNodeProperties { mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord; mDepth = depth; mLeavingDepth = leavingDepth; - memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos)); + memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds)); } // Init for root with prevWordsPtNodePos which is used for n-gram - void init(const int rootPtNodeArrayPos, const int *const prevWordsNodePos) { + void init(const int rootPtNodeArrayPos, const int *const prevWordIds) { mPtNodePos = NOT_A_DICT_POS; mChildrenPtNodeArrayPos = rootPtNodeArrayPos; mDicNodeCodePoint = NOT_A_CODE_POINT; @@ -63,7 +63,7 @@ class DicNodeProperties { mIsBlacklistedOrNotAWord = false; mDepth = 0; mLeavingDepth = 0; - memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos)); + memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds)); } void initByCopy(const DicNodeProperties *const dicNodeProp) { @@ -76,8 +76,7 @@ class DicNodeProperties { mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord; mDepth = dicNodeProp->mDepth; mLeavingDepth = dicNodeProp->mLeavingDepth; - memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos, - sizeof(mPrevWordsTerminalPtNodePos)); + memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds)); } // Init as passing child @@ -91,8 +90,7 @@ class DicNodeProperties { mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord; mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child 
mLeavingDepth = dicNodeProp->mLeavingDepth; - memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos, - sizeof(mPrevWordsTerminalPtNodePos)); + memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds)); } int getPtNodePos() const { @@ -132,8 +130,12 @@ class DicNodeProperties { return mIsBlacklistedOrNotAWord; } - const int *getPrevWordsTerminalPtNodePos() const { - return mPrevWordsTerminalPtNodePos; + const int *getPrevWordIds() const { + return mPrevWordIds; + } + + int getWordId() const { + return mWordId; } private: @@ -149,7 +151,7 @@ class DicNodeProperties { bool mIsBlacklistedOrNotAWord; uint16_t mDepth; uint16_t mLeavingDepth; - int mPrevWordsTerminalPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; }; } // namespace latinime #endif // LATINIME_DIC_NODE_PROPERTIES_H diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index c025bfcf5..956243161 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -93,11 +93,10 @@ void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, TimeKeeper::setCurrentTime(); NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults, mDictionaryStructureWithBufferPolicy.get()); - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos( - mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds, true /* tryLowerCaseSearch */); - mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordsPtNodePos, &listener); + mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener); } int Dictionary::getProbability(const int *word, int length) const { @@ -113,18 +112,17 @@ 
int Dictionary::getMaxProbabilityOfExactMatches(const int *word, int length) con int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word, int length) const { TimeKeeper::setCurrentTime(); - int nextWordPos = mDictionaryStructureWithBufferPolicy->getTerminalPtNodePositionOfWord( + int wordId = mDictionaryStructureWithBufferPolicy->getWordId( CodePointArrayView(word, length), false /* forceLowerCaseSearch */); - if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY; + if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY; if (!prevWordsInfo) { - return getDictionaryStructurePolicy()->getProbabilityOfPtNode( - nullptr /* prevWordsPtNodePos */, nextWordPos); + return getDictionaryStructurePolicy()->getProbabilityOfWord( + nullptr /* prevWordsPtNodePos */, wordId); } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos( - mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds, true /* tryLowerCaseSearch */); - return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsPtNodePos, nextWordPos); + return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId); } bool Dictionary::addUnigramEntry(const int *const word, const int length, diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp index b94966cbe..b372b6b4f 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp @@ -34,11 +34,11 @@ namespace latinime { // No prev words information. 
 PrevWordsInfo emptyPrevWordsInfo; - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - emptyPrevWordsInfo.getPrevWordsTerminalPtNodePos(dictionaryStructurePolicy, - prevWordsPtNodePos, false /* tryLowerCaseSearch */); + int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds, + false /* tryLowerCaseSearch */); current.emplace_back(); - DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordsPtNodePos, &current.front()); + DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, &current.front()); for (int i = 0; i < codePointCount; ++i) { // The base-lower input is used to ignore case errors and accent errors. const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]); diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp index 91f33a8dd..979d61edb 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp @@ -35,39 +35,37 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP = // Also caches the bigrams if there is space remaining and they have not been cached already. 
int MultiBigramMap::getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, + const int *const prevWordIds, const int nextWordId, const int unigramProbability) { - if (!prevWordsPtNodePos || prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (!prevWordIds || prevWordIds[0] == NOT_A_WORD_ID) { return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY); } - std::unordered_map::const_iterator mapPosition = - mBigramMaps.find(prevWordsPtNodePos[0]); + const auto mapPosition = mBigramMaps.find(prevWordIds[0]); if (mapPosition != mBigramMaps.end()) { - return mapPosition->second.getBigramProbability(structurePolicy, nextWordPosition, + return mapPosition->second.getBigramProbability(structurePolicy, nextWordId, unigramProbability); } if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { - addBigramsForWordPosition(structurePolicy, prevWordsPtNodePos); - return mBigramMaps[prevWordsPtNodePos[0]].getBigramProbability(structurePolicy, - nextWordPosition, unigramProbability); + addBigramsForWord(structurePolicy, prevWordIds); + return mBigramMaps[prevWordIds[0]].getBigramProbability(structurePolicy, + nextWordId, unigramProbability); } - return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordsPtNodePos, - nextWordPosition, unigramProbability); + return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordIds, + nextWordId, unigramProbability); } void MultiBigramMap::BigramMap::init( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos) { - structurePolicy->iterateNgramEntries(prevWordsPtNodePos, this /* listener */); + const int *const prevWordIds) { + structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */); } int MultiBigramMap::BigramMap::getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int 
nextWordPosition, const int unigramProbability) const { + const int nextWordId, const int unigramProbability) const { int bigramProbability = NOT_A_PROBABILITY; - if (mBloomFilter.isInFilter(nextWordPosition)) { - const std::unordered_map::const_iterator bigramProbabilityIt = - mBigramMap.find(nextWordPosition); + if (mBloomFilter.isInFilter(nextWordId)) { + const auto bigramProbabilityIt = mBigramMap.find(nextWordId); if (bigramProbabilityIt != mBigramMap.end()) { bigramProbability = bigramProbabilityIt->second; } @@ -75,29 +73,27 @@ int MultiBigramMap::BigramMap::getBigramProbability( return structurePolicy->getProbability(unigramProbability, bigramProbability); } -void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, - const int targetPtNodePos) { - if (targetPtNodePos == NOT_A_DICT_POS) { +void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, const int targetWordId) { + if (targetWordId == NOT_A_WORD_ID) { return; } - mBigramMap[targetPtNodePos] = ngramProbability; - mBloomFilter.setInFilter(targetPtNodePos); + mBigramMap[targetWordId] = ngramProbability; + mBloomFilter.setInFilter(targetWordId); } -void MultiBigramMap::addBigramsForWordPosition( +void MultiBigramMap::addBigramsForWord( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos) { - if (prevWordsPtNodePos) { - mBigramMaps[prevWordsPtNodePos[0]].init(structurePolicy, prevWordsPtNodePos); + const int *const prevWordIds) { + if (prevWordIds) { + mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds); } } int MultiBigramMap::readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, + const int *const prevWordIds, const int nextWordId, const int unigramProbability) { - const int bigramProbability = structurePolicy->getProbabilityOfPtNode(prevWordsPtNodePos, - nextWordPosition); + const int 
bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId); if (bigramProbability != NOT_A_PROBABILITY) { return bigramProbability; } diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index ad36dde83..a8c4ded57 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -39,8 +39,7 @@ class MultiBigramMap { // Look up the bigram probability for the given word pair from the cached bigram maps. // Also caches the bigrams if there is space remaining and they have not been cached already. int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, - const int unigramProbability); + const int *const prevWordIds, const int nextWordId, const int unigramProbability); void clear() { mBigramMaps.clear(); @@ -58,11 +57,11 @@ class MultiBigramMap { virtual ~BigramMap() {} void init(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos); + const int *const prevWordIds); int getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int nextWordPosition, const int unigramProbability) const; - virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); + const int nextWordId, const int unigramProbability) const; + virtual void onVisitEntry(const int ngramProbability, const int targetWordId); private: static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP; @@ -70,14 +69,12 @@ class MultiBigramMap { BloomFilter mBloomFilter; }; - void addBigramsForWordPosition( - const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos); + void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy, + const int *const prevWordIds); int 
readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, - const int unigramProbability); + const int *const prevWordIds, const int nextWordId, const int unigramProbability); static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; std::unordered_map mBigramMaps; diff --git a/native/jni/src/suggest/core/dictionary/ngram_listener.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h index 88b88bafb..e9b3c1aaf 100644 --- a/native/jni/src/suggest/core/dictionary/ngram_listener.h +++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h @@ -26,7 +26,7 @@ namespace latinime { */ class NgramListener { public: - virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos) = 0; + virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0; virtual ~NgramListener() {}; protected: diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index 0faf00003..72ec13fe8 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -53,15 +53,14 @@ class DictionaryStructureWithBufferPolicy { const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const = 0; - virtual int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, + virtual int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const = 0; virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0; - virtual int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const = 0; + virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 
0; - virtual void iterateNgramEntries(const int *const prevWordsPtNodePos, + virtual void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const = 0; virtual int getShortcutPositionOfPtNode(const int ptNodePos) const = 0; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index f1e411f38..d4d4d1eed 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -35,8 +35,8 @@ void DicTraverseSession::init(const Dictionary *const dictionary, mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy() ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; - prevWordsInfo->getPrevWordsTerminalPtNodePos( - getDictionaryStructurePolicy(), mPrevWordsPtNodePos, true /* tryLowerCaseSearch */); + prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds, + true /* tryLowerCaseSearch */); } void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 5a51a112d..0e676d897 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -55,8 +55,8 @@ class DicTraverseSession { mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. 
- for (size_t i = 0; i < NELEMS(mPrevWordsPtNodePos); ++i) { - mPrevWordsPtNodePos[i] = NOT_A_DICT_POS; + for (size_t i = 0; i < NELEMS(mPrevWordsIds); ++i) { + mPrevWordsIds[i] = NOT_A_DICT_POS; } } @@ -79,7 +79,7 @@ class DicTraverseSession { //-------------------- const ProximityInfo *getProximityInfo() const { return mProximityInfo; } const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; } - const int *getPrevWordsPtNodePos() const { return mPrevWordsPtNodePos; } + const int *getPrevWordIds() const { return mPrevWordsIds; } DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; } const ProximityInfoState *getProximityInfoState(int id) const { @@ -166,7 +166,7 @@ class DicTraverseSession { const int *const inputYs, const int *const times, const int *const pointerIds, const int inputSize, const float maxSpatialDistance, const int maxPointerCount); - int mPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + int mPrevWordsIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; const ProximityInfo *mProximityInfo; const Dictionary *mDictionary; const SuggestOptions *mSuggestOptions; diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h index 9b3a7d468..fc9a35968 100644 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ b/native/jni/src/suggest/core/session/prev_words_info.h @@ -18,14 +18,12 @@ #define LATINIME_PREV_WORDS_INFO_H #include "defines.h" -#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" #include "utils/int_array_view.h" namespace latinime { -// TODO: Support n-gram. class PrevWordsInfo { public: // No prev word information. 
@@ -81,11 +79,10 @@ class PrevWordsInfo { return false; } - void getPrevWordsTerminalPtNodePos( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - int *const outPrevWordsTerminalPtNodePos, const bool tryLowerCaseSearch) const { + void getPrevWordIds(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + int *const outPrevWordIds, const bool tryLowerCaseSearch) const { for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - outPrevWordsTerminalPtNodePos[i] = getTerminalPtNodePosOfWord(dictStructurePolicy, + outPrevWordIds[i] = getWordId(dictStructurePolicy, mPrevWordCodePoints[i], mPrevWordCodePointCount[i], mIsBeginningOfSentence[i], tryLowerCaseSearch); } @@ -110,12 +107,11 @@ class PrevWordsInfo { private: DISALLOW_COPY_AND_ASSIGN(PrevWordsInfo); - static int getTerminalPtNodePosOfWord( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, const int *const wordCodePoints, const int wordCodePointCount, const bool isBeginningOfSentence, const bool tryLowerCaseSearch) { if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { - return NOT_A_DICT_POS; + return NOT_A_WORD_ID; } int codePoints[MAX_WORD_LENGTH]; int codePointCount = wordCodePointCount; @@ -124,21 +120,19 @@ class PrevWordsInfo { codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, codePointCount, MAX_WORD_LENGTH); if (codePointCount <= 0) { - return NOT_A_DICT_POS; + return NOT_A_WORD_ID; } } const CodePointArrayView codePointArrayView(codePoints, codePointCount); - const int wordPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord( + const int wordId = dictStructurePolicy->getWordId( codePointArrayView, false /* forceLowerCaseSearch */); - if (wordPtNodePos != NOT_A_DICT_POS || !tryLowerCaseSearch) { - // Return the position when when the word was found or doesn't try lower case - // 
 search. - return wordPtNodePos; + if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) { + // Return the id when the word was found or doesn't try lower case search. + return wordId; } // Check bigrams for lower-cased previous word if original was not found. Useful for // auto-capitalized words like "The [current_word]". - return dictStructurePolicy->getTerminalPtNodePositionOfWord( - codePointArrayView, true /* forceLowerCaseSearch */); + return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */); } void clear() { diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 0cd305f5a..66c87f04c 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -92,7 +92,7 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession) const { // Create a new dic node here DicNode rootNode; DicNodeUtils::initAsRoot(traverseSession->getDictionaryStructurePolicy(), - traverseSession->getPrevWordsPtNodePos(), &rootNode); + traverseSession->getPrevWordIds(), &rootNode); traverseSession->getDicTraverseCache()->copyPushActive(&rootNode); } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 9f6ae114d..5dff1fc97 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -104,7 +104,7 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( return codePointCount; } -int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, +int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const { DynamicPtReadingHelper 
readingHelper(&mNodeReader, &mPtNodeArrayReader); readingHelper.initWithPtNodeArrayPos(getRootPosition()); @@ -112,9 +112,9 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArray wordCodePoints.size(), forceLowerCaseSearch); if (readingHelper.isError()) { mIsCorrupted = true; - AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); + AKLOGE("Dictionary reading error in getWordId()."); } - return ptNodePos; + return getWordIdFromTerminalPtNodePos(ptNodePos); } int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, @@ -133,17 +133,19 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, + const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordsPtNodePos) { - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + if (prevWordIds) { + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); @@ -157,16 +159,18 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtN return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } -void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, +void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds, 
NgramListener *const listener) const { - if (!prevWordsPtNodePos) { + if (!prevWordIds) { return; } - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); - listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + listener->onVisitEntry(bigramsIt.getProbability(), + getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos())); } } @@ -238,8 +242,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo } if (unigramProperty->getShortcuts().size() > 0) { // Add shortcut target. - const int wordPos = getTerminalPtNodePositionOfWord(codePointArrayView, - false /* forceLowerCaseSearch */); + const int wordPos = getTerminalPtNodePosFromWordId( + getWordId(codePointArrayView, false /* forceLowerCaseSearch */)); if (wordPos == NOT_A_DICT_POS) { AKLOGE("Cannot find terminal PtNode position to add shortcut target."); return false; @@ -266,8 +270,8 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); return false; } - const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, - false /* forceLowerCaseSearch */); + const int ptNodePos = getTerminalPtNodePosFromWordId( + getWordId(wordCodePoints, false /* forceLowerCaseSearch */)); if (ptNodePos == NOT_A_DICT_POS) { return false; } @@ -295,11 +299,9 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSearch */); - // TODO: Support N-gram. 
- if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); + if (prevWordIds[0] == NOT_A_WORD_ID) { if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { const std::vector shortcuts; const UnigramProperty beginningOfSentenceUnigramProperty( @@ -311,22 +313,22 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); return false; } - // Refresh Terminal PtNode positions. - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSearch */); + // Refresh word ids. + prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); } else { return false; } } - const int word1Pos = getTerminalPtNodePositionOfWord( + const int wordPos = getTerminalPtNodePosFromWordId(getWordId( CodePointArrayView(*bigramProperty->getTargetCodePoints()), - false /* forceLowerCaseSearch */); - if (word1Pos == NOT_A_DICT_POS) { + false /* forceLowerCaseSearch */)); + if (wordPos == NOT_A_DICT_POS) { return false; } bool addedNewBigram = false; - if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(prevWordsPtNodePos), - word1Pos, bigramProperty, &addedNewBigram)) { + const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); + if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos), + wordPos, bigramProperty, &addedNewBigram)) { if (addedNewBigram) { mBigramCount++; } @@ -355,20 +357,19 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor AKLOGE("word is too long to remove n-gram entry form the dictionary. 
length: %zd", wordCodePoints.size()); } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSerch */); - // TODO: Support N-gram. - if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */); + if (prevWordIds[0] == NOT_A_WORD_ID) { return false; } - const int wordPos = getTerminalPtNodePositionOfWord(wordCodePoints, - false /* forceLowerCaseSearch */); + const int wordPos = getTerminalPtNodePosFromWordId(getWordId(wordCodePoints, + false /* forceLowerCaseSearch */)); if (wordPos == NOT_A_DICT_POS) { return false; } + const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); if (mUpdatingHelper.removeNgramEntry( - PtNodePosArrayView::fromObject(prevWordsPtNodePos), wordPos)) { + PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) { mBigramCount--; return true; } else { @@ -449,8 +450,8 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer const WordProperty Ver4PatriciaTriePolicy::getWordProperty( const CodePointArrayView wordCodePoints) const { - const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, - false /* forceLowerCaseSearch */); + const int ptNodePos = getTerminalPtNodePosFromWordId( + getWordId(wordCodePoints, false /* forceLowerCaseSearch */)); if (ptNodePos == NOT_A_DICT_POS) { AKLOGE("getWordProperty is called for invalid word."); return WordProperty(); @@ -553,6 +554,14 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const return nextToken; } +int Ver4PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const { + return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos; +} + +int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const { + return wordId == NOT_A_WORD_ID ? 
NOT_A_DICT_POS : wordId; +} + } // namespace v402 } // namespace backward } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index df119e3a1..2ebe9ba5a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -87,15 +87,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const; - int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, - const bool forceLowerCaseSearch) const; + int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordsPtNodePos, - NgramListener *const listener) const; + void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const; int getShortcutPositionOfPtNode(const int ptNodePos) const; @@ -164,6 +162,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { mutable bool mIsCorrupted; int getBigramsPositionOfPtNode(const int ptNodePos) const; + int getWordIdFromTerminalPtNodePos(const int ptNodePos) const; + int getTerminalPtNodePosFromWordId(const int wordId) const; }; } // namespace v402 } // namespace backward diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp 
b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index 4ac366e07..85971f1f2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -267,8 +267,8 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( } // This function gets the position of the terminal PtNode of the exact matching word in the -// dictionary. If no match is found, it returns NOT_A_DICT_POS. -int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, +// dictionary. If no match is found, it returns NOT_A_WORD_ID. +int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const { DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader); readingHelper.initWithPtNodeArrayPos(getRootPosition()); @@ -276,9 +276,9 @@ int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints.size(), forceLowerCaseSearch); if (readingHelper.isError()) { mIsCorrupted = true; - AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); + AKLOGE("Dictionary reading error in getWordId()."); } - return ptNodePos; + return getWordIdFromTerminalPtNodePos(ptNodePos); } int PatriciaTriePolicy::getProbability(const int unigramProbability, @@ -297,11 +297,11 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams = 
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) { @@ -310,8 +310,9 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP // for shortcuts). return NOT_A_PROBABILITY; } - if (prevWordsPtNodePos) { - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + if (prevWordIds) { + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); @@ -325,16 +326,18 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } -void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, +void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const { - if (!prevWordsPtNodePos) { + if (!prevWordIds) { return; } - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); - listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + listener->onVisitEntry(bigramsIt.getProbability(), + getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos())); } } @@ -379,12 +382,12 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod const WordProperty PatriciaTriePolicy::getWordProperty( const CodePointArrayView wordCodePoints) const { - const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, - false /* forceLowerCaseSearch */); - if (ptNodePos == NOT_A_DICT_POS) { + const int wordId = getWordId(wordCodePoints, 
false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { AKLOGE("getWordProperty was called for invalid word."); return WordProperty(); } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams = mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); std::vector codePointVector(ptNodeParams.getCodePoints(), @@ -467,4 +470,11 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC return nextToken; } +int PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const { + return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos; +} + +int PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const { + return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId; +} } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 4d9af2877..31fee7742 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -64,15 +64,13 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const; - int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, - const bool forceLowerCaseSearch) const; + int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordsPtNodePos, - NgramListener *const listener) const; 
+ void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const; int getShortcutPositionOfPtNode(const int ptNodePos) const; @@ -163,6 +161,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getBigramsPositionOfPtNode(const int ptNodePos) const; int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos, DicNodeVector *const childDicNodes) const; + int getWordIdFromTerminalPtNodePos(const int ptNodePos) const; + int getTerminalPtNodePosFromWordId(const int wordId) const; }; } // namespace latinime #endif // LATINIME_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 619cdb59b..7024682f6 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -94,7 +94,7 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( return codePointCount; } -int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, +int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); readingHelper.initWithPtNodeArrayPos(getRootPosition()); @@ -104,7 +104,11 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArray mIsCorrupted = true; AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); } - return ptNodePos; + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_WORD_ID; + } + const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + return ptNodeParams.getTerminalId(); } int Ver4PatriciaTriePolicy::getProbability(const int 
unigramProbability, @@ -123,24 +127,22 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, + const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordsPtNodePos) { + if (prevWordIds) { // TODO: Support n-gram. - const PtNodeParams prevWordPtNodeParams = - mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordsPtNodePos[0]); - const int prevWordTerminalId = prevWordPtNodeParams.getTerminalId(); const ProbabilityEntry probabilityEntry = mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry( - IntArrayView::fromObject(&prevWordTerminalId), - ptNodeParams.getTerminalId()); + IntArrayView::fromObject(prevWordIds), wordId); if (!probabilityEntry.isValid()) { return NOT_A_PROBABILITY; } @@ -154,26 +156,21 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtN return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } -void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, +void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const { - if (!prevWordsPtNodePos) { + if (!prevWordIds) { return; } // TODO: Support n-gram. 
- const PtNodeParams ptNodeParams = - mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordsPtNodePos[0]); - const int prevWordId = ptNodeParams.getTerminalId(); - const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&prevWordId); const auto languageModelDictContent = mBuffers->getLanguageModelDictContent(); - for (const auto entry : languageModelDictContent->getProbabilityEntries(prevWordIds)) { + for (const auto entry : languageModelDictContent->getProbabilityEntries( + WordIdArrayView::fromObject(prevWordIds))) { const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); const int probability = probabilityEntry.hasHistoricalInfo() ? ForgettingCurveUtils::decodeProbability( probabilityEntry.getHistoricalInfo(), mHeaderPolicy) : probabilityEntry.getProbability(); - const int ptNodePos = mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition( - entry.getWordId()); - listener->onVisitEntry(probability, ptNodePos); + listener->onVisitEntry(probability, entry.getWordId()); } } @@ -233,12 +230,13 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo } if (unigramProperty->getShortcuts().size() > 0) { // Add shortcut target. 
- const int wordPos = getTerminalPtNodePositionOfWord(codePointArrayView, - false /* forceLowerCaseSearch */); - if (wordPos == NOT_A_DICT_POS) { - AKLOGE("Cannot find terminal PtNode position to add shortcut target."); + const int wordId = getWordId(codePointArrayView, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { + AKLOGE("Cannot find word id to add shortcut target."); return false; } + const int wordPos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); for (const auto &shortcut : unigramProperty->getShortcuts()) { if (!mUpdatingHelper.addShortcutTarget(wordPos, shortcut.getTargetCodePoints()->data(), @@ -261,20 +259,19 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); return false; } - const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, - false /* forceLowerCaseSearch */); - if (ptNodePos == NOT_A_DICT_POS) { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { return false; } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); if (!mNodeWriter.markPtNodeAsDeleted(&ptNodeParams)) { AKLOGE("Cannot remove unigram. ptNodePos: %d", ptNodePos); return false; } - if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry( - ptNodeParams.getTerminalId())) { - // TODO: Uncomment. 
- // return false; + if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry(wordId)) { + return false; } if (!ptNodeParams.representsNonWordInfo()) { mUnigramCount--; @@ -302,12 +299,10 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSearch */); - const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos); + int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); // TODO: Support N-gram. - if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (prevWordIds[0] == NOT_A_WORD_ID) { if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { const std::vector shortcuts; const UnigramProperty beginningOfSentenceUnigramProperty( @@ -319,22 +314,27 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); return false; } - // Refresh Terminal PtNode positions. - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSearch */); + // Refresh word ids. 
+ prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); } else { return false; } } - const int word1Pos = getTerminalPtNodePositionOfWord( - CodePointArrayView(*bigramProperty->getTargetCodePoints()), + const int wordId = getWordId(CodePointArrayView(*bigramProperty->getTargetCodePoints()), false /* forceLowerCaseSearch */); - if (word1Pos == NOT_A_DICT_POS) { + if (wordId == NOT_A_WORD_ID) { return false; } bool addedNewEntry = false; - if (mUpdatingHelper.addNgramEntry(prevWordsPtNodePosView, word1Pos, bigramProperty, - &addedNewEntry)) { + int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + for (size_t i = 0; i < NELEMS(prevWordIds); ++i) { + prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable() + ->getTerminalPtNodePosition(prevWordIds[i]); + } + const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable() + ->getTerminalPtNodePosition(wordId); + if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos), + wordPtNodePos, bigramProperty, &addedNewEntry)) { if (addedNewEntry) { mBigramCount++; } @@ -363,20 +363,25 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd", wordCodePoints.size()); } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSerch */); - const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos); + int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */); // TODO: Support N-gram. 
- if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (prevWordIds[0] == NOT_A_WORD_ID) { return false; } - const int wordPos = getTerminalPtNodePositionOfWord(wordCodePoints, - false /* forceLowerCaseSearch */); - if (wordPos == NOT_A_DICT_POS) { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { return false; } - if (mUpdatingHelper.removeNgramEntry(prevWordsPtNodePosView, wordPos)) { + int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + for (size_t i = 0; i < NELEMS(prevWordIds); ++i) { + prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable() + ->getTerminalPtNodePosition(prevWordIds[i]); + } + const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable() + ->getTerminalPtNodePosition(wordId); + if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos), + wordPtNodePos)) { mBigramCount--; return true; } else { @@ -457,12 +462,13 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer const WordProperty Ver4PatriciaTriePolicy::getWordProperty( const CodePointArrayView wordCodePoints) const { - const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, - false /* forceLowerCaseSearch */); - if (ptNodePos == NOT_A_DICT_POS) { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { AKLOGE("getWordProperty is called for invalid word."); return WordProperty(); } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); std::vector codePointVector(ptNodeParams.getCodePoints(), ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); @@ -473,7 +479,6 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( // Fetch bigram information. // TODO: Support n-gram. 
std::vector bigrams; - const int wordId = ptNodeParams.getTerminalId(); const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId); const TerminalPositionLookupTable *const terminalPositionLookupTable = mBuffers->getTerminalPositionLookupTable(); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index 24f92a4aa..1d2712a4b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -66,15 +66,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const; - int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, - const bool forceLowerCaseSearch) const; + int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordsPtNodePos, - NgramListener *const listener) const; + void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const; int getShortcutPositionOfPtNode(const int ptNodePos) const;