diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index e4edc5ab6..b3479738e 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -109,9 +109,7 @@ void UnigramDictionary::registerNextLetter( } } -bool -UnigramDictionary::addWord(unsigned short *word, int length, int frequency) -{ +bool UnigramDictionary::addWord(unsigned short *word, int length, int frequency) { word[length] = 0; if (DEBUG_DICT) { char s[length + 1]; @@ -147,8 +145,7 @@ UnigramDictionary::addWord(unsigned short *word, int length, int frequency) return false; } -unsigned short -UnigramDictionary::toLowerCase(unsigned short c) { +unsigned short UnigramDictionary::toLowerCase(unsigned short c) { if (c < sizeof(BASE_CHARS) / sizeof(BASE_CHARS[0])) { c = BASE_CHARS[c]; } @@ -160,9 +157,7 @@ UnigramDictionary::toLowerCase(unsigned short c) { return c; } -bool -UnigramDictionary::sameAsTyped(unsigned short *word, int length) -{ +bool UnigramDictionary::sameAsTyped(unsigned short *word, int length) { if (length != mInputLength) { return false; } @@ -180,15 +175,10 @@ UnigramDictionary::sameAsTyped(unsigned short *word, int length) static const char QUOTE = '\''; // snr : frequency? -void -UnigramDictionary::getWordsRec(int pos, int depth, int maxDepth, bool completion, int snr, - int inputIndex, int diffs, int skipPos, int *nextLetters, int nextLettersSize) -{ +void UnigramDictionary::getWordsRec(int pos, int depth, int maxDepth, bool traverseAllNodes, + int snr, int inputIndex, int diffs, int skipPos, int *nextLetters, int nextLettersSize) { // Optimization: Prune out words that are too long compared to how much was typed. - if (depth > maxDepth) { - return; - } - if (diffs > mMaxEditDistance) { + if (depth > maxDepth || diffs > mMaxEditDistance) { return; } // get the count of nodes and increment pos. @@ -196,71 +186,59 @@ UnigramDictionary::getWordsRec(int pos, int depth, int maxDepth, bool completion int *currentChars = NULL; // If inputIndex is greater than mInputLength, that means there are no proximity chars. if (mInputLength <= inputIndex) { - completion = true; + traverseAllNodes = true; } else { currentChars = mInputCodes + (inputIndex * MAX_ALTERNATIVES); } - for (int i = 0; i < count; i++) { + for (int i = 0; i < count; ++i) { // -- at char - unsigned short c = Dictionary::getChar(DICT, &pos); + const unsigned short c = Dictionary::getChar(DICT, &pos); // -- at flag/add - unsigned short lowerC = toLowerCase(c); - bool terminal = Dictionary::getTerminal(DICT, &pos); - int childrenAddress = Dictionary::getAddress(DICT, &pos); - const bool needsToContinue = childrenAddress != 0; + const unsigned short lowerC = toLowerCase(c); + const bool terminal = Dictionary::getTerminal(DICT, &pos); + const int childrenAddress = Dictionary::getAddress(DICT, &pos); + int matchedProximityCharId = -1; + const bool needsToTraverseNextNode = childrenAddress != 0; // -- after address or flag int freq = 1; // If terminal, increment pos if (terminal) freq = Dictionary::getFreq(DICT, IS_LATEST_DICT_VERSION, &pos); // -- after add or freq + bool newTraverseAllNodes = traverseAllNodes; + int newSnr = snr; + int newDiffs = diffs; + int newInputIndex = inputIndex; - // If we are only doing completions, no need to look at the typed characters. - if (completion) { + // If we are only doing traverseAllNodes, no need to look at the typed characters. + if (traverseAllNodes || needsToSkipCurrentNode(c, currentChars[0], skipPos, depth)) { mWord[depth] = c; - if (terminal) { + if (traverseAllNodes && terminal) { onTerminalWhenUserTypedLengthIsGreaterThanInputLength(mWord, mInputLength, depth, snr, nextLetters, nextLettersSize, skipPos, freq); } - if (needsToContinue) { - // No need to do proximity suggest any more. - getWordsRec(childrenAddress, depth + 1, maxDepth, true, snr, inputIndex, - diffs, skipPos, nextLetters, nextLettersSize); - } - } else if ((c == QUOTE && currentChars[0] != QUOTE) || skipPos == depth) { - // Skip the ' or other letter and continue deeper - mWord[depth] = c; - if (needsToContinue) { - getWordsRec(childrenAddress, depth + 1, maxDepth, false, snr, inputIndex, - diffs, skipPos, nextLetters, nextLettersSize); - } } else { - int j = 0; - while (currentChars[j] > 0) { - // Move to child node - if (currentChars[j] == lowerC || currentChars[j] == c) { - mWord[depth] = c; - const int addedWeight = j == 0 ? TYPED_LETTER_MULTIPLIER : 1; - const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1; - // If inputIndex is greater than mInputLength, that means there is no - // proximity chars. So, we don't need to check proximity. - if (isSameAsUserTypedLength) { - if (terminal) { - onTerminalWhenUserTypedLengthIsSameAsInputLength(mWord, depth, snr, - skipPos, freq, addedWeight); - } - } - if (needsToContinue) { - getWordsRec(childrenAddress, depth + 1, maxDepth, - isSameAsUserTypedLength, snr * addedWeight, inputIndex + 1, - diffs + (j > 0), skipPos, nextLetters, nextLettersSize); - } - } - ++j; - // If skipPos is defined, not to search proximity collections. - // First char is what user typed. - if (skipPos >= 0) break; + matchedProximityCharId = getMatchedProximityId(currentChars, lowerC, c, skipPos); + if (matchedProximityCharId < 0) continue; + mWord[depth] = c; + // If inputIndex is greater than mInputLength, that means there is no + // proximity chars. So, we don't need to check proximity. + const int addedWeight = matchedProximityCharId == 0 ? TYPED_LETTER_MULTIPLIER : 1; + const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1; + if (isSameAsUserTypedLength && terminal) { + onTerminalWhenUserTypedLengthIsSameAsInputLength(mWord, depth, snr, + skipPos, freq, addedWeight); } + if (!needsToTraverseNextNode) continue; + // Start traversing all nodes after the index exceeds the user typed length + newTraverseAllNodes = isSameAsUserTypedLength; + newSnr *= addedWeight; + newDiffs += (matchedProximityCharId > 0); + ++newInputIndex; + } + if (needsToTraverseNextNode) { + getWordsRec(childrenAddress, depth + 1, maxDepth, newTraverseAllNodes, + newSnr, newInputIndex, newDiffs, skipPos, nextLetters, nextLettersSize); } } } @@ -285,4 +263,29 @@ inline void UnigramDictionary::onTerminalWhenUserTypedLengthIsSameAsInputLength( addWord(word, depth + 1, finalFreq); } } + +inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c, + const unsigned short userTypedChar, const int skipPos, const int depth) { + // Skip the ' or other letter and continue deeper + return (c == QUOTE && userTypedChar != QUOTE) || skipPos == depth; +} + +inline int UnigramDictionary::getMatchedProximityId(const int *currentChars, + const unsigned short lowerC, const unsigned short c, const int skipPos) { + bool matched = false; + int j = 0; + while (currentChars[j] > 0) { + matched = (currentChars[j] == lowerC || currentChars[j] == c); + // If skipPos is defined, not to search proximity collections. + // First char is what user typed. + if (matched) { + return j; + } else if (skipPos >= 0) { + return -1; + } + ++j; + } + return -1; +} + } // namespace latinime diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h index 118d7dc29..259276cea 100644 --- a/native/src/unigram_dictionary.h +++ b/native/src/unigram_dictionary.h @@ -53,6 +53,12 @@ private: void onTerminalWhenUserTypedLengthIsSameAsInputLength(unsigned short *word, const int depth, const int snr, const int skipPos, const int freq, const int addedWeight); + bool needsToSkipCurrentNode(const unsigned short c, + const unsigned short userTypedChar, const int skipPos, const int depth); + + int getMatchedProximityId(const int *currentChars, const unsigned short lowerC, + const unsigned short c, const int skipPos); + const unsigned char *DICT; const int MAX_WORDS; const int MAX_WORD_LENGTH;