diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index 707f1e6fb..d307ba2f8 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -179,7 +179,8 @@ bool UnigramDictionary::sameAsTyped(unsigned short *word, int length) { static const char QUOTE = '\''; -void UnigramDictionary::getWords(const int initialPos, const int inputLength, const int skipPos, +// Former getWords, renamed; kept only for comparison against the new getWords +void UnigramDictionary::getWordsOld(const int initialPos, const int inputLength, const int skipPos, int *nextLetters, const int nextLettersSize) { int initialPosition = initialPos; const int count = Dictionary::getCount(DICT, &initialPosition); @@ -188,6 +189,55 @@ void UnigramDictionary::getWords(const int initialPos, const int inputLength, co mInputLength <= 0, 1, 0, 0, skipPos, nextLetters, nextLettersSize); } +void UnigramDictionary::getWords(const int rootPos, const int inputLength, const int skipPos, + int *nextLetters, const int nextLettersSize) { + int rootPosition = rootPos; + const int MAX_DEPTH = min(inputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH); + // Get the number of children of the root, then advance the position + int childCount = Dictionary::getCount(DICT, &rootPosition); + int depth = 0; + // Seed stack slot 0 with the state of the root level + mStackChildCount[0] = childCount; + mStackTraverseAll[0] = (mInputLength <= 0); + mStackNodeFreq[0] = 1; + mStackInputIndex[0] = 0; + mStackDiffs[0] = 0; + mStackSiblingPos[0] = rootPosition; + // Iterative depth-first walk; ends when the root level is exhausted and depth drops below 0 + while (depth >= 0) { + if (mStackChildCount[depth] > 0) { + --mStackChildCount[depth]; + bool traverseAllNodes = mStackTraverseAll[depth]; + int snr = mStackNodeFreq[depth]; + int inputIndex = mStackInputIndex[depth]; + int diffs = mStackDiffs[depth]; + int siblingPos = mStackSiblingPos[depth]; + int firstChildPos; + // depth never exceeds MAX_DEPTH: processCurrentNode returns false once + // depth >= maxDepth, so no deeper stack slot is ever pushed + const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, 
depth, + MAX_DEPTH, traverseAllNodes, snr, inputIndex, diffs, skipPos, nextLetters, + nextLettersSize, &childCount, &firstChildPos, &traverseAllNodes, &snr, + &inputIndex, &diffs, &siblingPos); + // Remember where the next sibling at this depth starts + mStackSiblingPos[depth] = siblingPos; + if (needsToTraverseChildrenNodes) { + // Descend: push the state for the child level onto the stacks + ++depth; + mStackChildCount[depth] = childCount; + mStackTraverseAll[depth] = traverseAllNodes; + mStackNodeFreq[depth] = snr; + mStackInputIndex[depth] = inputIndex; + mStackDiffs[depth] = diffs; + mStackSiblingPos[depth] = firstChildPos; + } + } else { + // No nodes left at this depth: go back up to the parent level + --depth; + } + } +} + // snr : frequency? void UnigramDictionary::getWordsRec(const int childrenCount, const int pos, const int depth, const int maxDepth, const bool traverseAllNodes, const int snr, const int inputIndex, @@ -196,7 +246,7 @@ void UnigramDictionary::getWordsRec(const int childrenCount, const int pos, cons for (int i = 0; i < childrenCount; ++i) { int newCount; int newChildPosition; - int newDepth; + const int newDepth = depth + 1; bool newTraverseAllNodes; int newSnr; int newInputIndex; @@ -204,7 +254,7 @@ void UnigramDictionary::getWordsRec(const int childrenCount, const int pos, cons int newSiblingPos; const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, depth, maxDepth, traverseAllNodes, snr, inputIndex, diffs, skipPos, nextLetters, nextLettersSize, - &newCount, &newChildPosition, &newDepth, &newTraverseAllNodes, &newSnr, + &newCount, &newChildPosition, &newTraverseAllNodes, &newSnr, &newInputIndex, &newDiffs, &newSiblingPos); siblingPos = newSiblingPos; @@ -264,7 +314,7 @@ inline int UnigramDictionary::getMatchedProximityId(const int *currentChars, inline bool UnigramDictionary::processCurrentNode(const int pos, const int depth, const int maxDepth, const bool traverseAllNodes, const int snr, const int inputIndex, const int diffs, const int skipPos, int *nextLetters, const int nextLettersSize, - int *newCount, int *newChildPosition, int 
*newDepth, bool *newTraverseAllNodes, + int *newCount, int *newChildPosition, bool *newTraverseAllNodes, int *newSnr, int*newInputIndex, int *newDiffs, int *nextSiblingPosition) { unsigned short c; int childPosition; @@ -287,7 +337,6 @@ inline bool UnigramDictionary::processCurrentNode(const int pos, const int depth *newSnr = snr; *newDiffs = diffs; *newInputIndex = inputIndex; - *newDepth = depth + 1; } else { int *currentChars = mInputCodes + (inputIndex * MAX_ALTERNATIVES); int matchedProximityCharId = getMatchedProximityId(currentChars, c, skipPos); @@ -307,10 +356,9 @@ inline bool UnigramDictionary::processCurrentNode(const int pos, const int depth *newSnr = snr * addedWeight; *newDiffs = diffs + (matchedProximityCharId > 0); *newInputIndex = inputIndex + 1; - *newDepth = depth + 1; } // Optimization: Prune out words that are too long compared to how much was typed. - if (*newDepth > maxDepth || *newDiffs > mMaxEditDistance) { + if (depth >= maxDepth || *newDiffs > mMaxEditDistance) { return false; } diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h index c02d366de..c53e77c0d 100644 --- a/native/src/unigram_dictionary.h +++ b/native/src/unigram_dictionary.h @@ -44,8 +44,11 @@ private: void getWordsRec(const int childrenCount, const int pos, const int depth, const int maxDepth, const bool traverseAllNodes, const int snr, const int inputIndex, const int diffs, const int skipPos, int *nextLetters, const int nextLettersSize); - void getWords(const int initialPos, const int inputLength, const int skipPos, int *nextLetters, - const int nextLettersSize); + void getWords(const int rootPos, const int inputLength, const int skipPos, + int *nextLetters, const int nextLettersSize); + // getWordsOld is the previous getWords; kept for comparing performance between the two + void getWordsOld(const int initialPos, const int inputLength, const int skipPos, + int *nextLetters, const int nextLettersSize); void registerNextLetter(unsigned short c, int 
*nextLetters, int nextLettersSize); void onTerminalWhenUserTypedLengthIsGreaterThanInputLength(unsigned short *word, const int mInputLength, const int depth, const int snr, int *nextLetters, @@ -58,7 +61,7 @@ private: bool processCurrentNode(const int pos, const int depth, const int maxDepth, const bool traverseAllNodes, const int snr, const int inputIndex, const int diffs, const int skipPos, int *nextLetters, const int nextLettersSize, - int *newCount, int *newChildPosition, int *newDepth, bool *newTraverseAllNodes, + int *newCount, int *newChildPosition, bool *newTraverseAllNodes, int *newSnr, int*newInputIndex, int *newDiffs, int *nextSiblingPosition); const unsigned char *DICT; const int MAX_WORDS; @@ -75,6 +78,13 @@ private: // MAX_WORD_LENGTH_INTERNAL must be bigger than MAX_WORD_LENGTH unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; int mMaxEditDistance; + // Per-depth traversal state used by getWords; each array is indexed by depth + int mStackChildCount[MAX_WORD_LENGTH_INTERNAL]; + bool mStackTraverseAll[MAX_WORD_LENGTH_INTERNAL]; + int mStackNodeFreq[MAX_WORD_LENGTH_INTERNAL]; + int mStackInputIndex[MAX_WORD_LENGTH_INTERNAL]; + int mStackDiffs[MAX_WORD_LENGTH_INTERNAL]; + int mStackSiblingPos[MAX_WORD_LENGTH_INTERNAL]; }; // ----------------------------------------------------------------------------