From 8876b75ca1c218949539dcc2fb6c88a19da9e3f8 Mon Sep 17 00:00:00 2001 From: satok Date: Thu, 4 Aug 2011 18:31:57 +0900 Subject: [PATCH] Move scoring part to the correction state Change-Id: I2dc4a0869636fce5526f48b3a6267b6bdf61dbfb --- native/src/correction_state.cpp | 131 ++++++++++++++++-- native/src/correction_state.h | 71 +++++++--- native/src/unigram_dictionary.cpp | 219 ++++++++++-------------------- native/src/unigram_dictionary.h | 12 +- 4 files changed, 248 insertions(+), 185 deletions(-) diff --git a/native/src/correction_state.cpp b/native/src/correction_state.cpp index b2c77b00d..9000e9e9c 100644 --- a/native/src/correction_state.cpp +++ b/native/src/correction_state.cpp @@ -25,13 +25,31 @@ namespace latinime { +////////////////////// +// inline functions // +////////////////////// +static const char QUOTE = '\''; + +inline bool CorrectionState::needsToSkipCurrentNode(const unsigned short c) { + const unsigned short userTypedChar = mProximityInfo->getPrimaryCharAt(mInputIndex); + // Skip the ' or other letter and continue deeper + return (c == QUOTE && userTypedChar != QUOTE) || mSkipPos == mOutputIndex; +} + +///////////////////// +// CorrectionState // +///////////////////// + CorrectionState::CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier) : TYPED_LETTER_MULTIPLIER(typedLetterMultiplier), FULL_WORD_MULTIPLIER(fullWordMultiplier) { } -void CorrectionState::initCorrectionState(const ProximityInfo *pi, const int inputLength) { +void CorrectionState::initCorrectionState(const ProximityInfo *pi, const int inputLength, + const int maxDepth) { mProximityInfo = pi; mInputLength = inputLength; + mMaxDepth = maxDepth; + mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2; } void CorrectionState::setCorrectionParams(const int skipPos, const int excessivePos, @@ -58,27 +76,37 @@ int CorrectionState::getFreqForSplitTwoWords(const int firstFreq, const int seco return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this); } -int CorrectionState::getFinalFreq(const unsigned short *word, const int freq) { - if (mProximityInfo->sameAsTyped(word, mOutputIndex + 1) || mOutputIndex < MIN_SUGGEST_DEPTH) { +int CorrectionState::getFinalFreq(const int freq, unsigned short **word, int *wordLength) { + const int outputIndex = mOutputIndex - 1; + const int inputIndex = (mCurrentStateType == TRAVERSE_ALL_ON_TERMINAL + || mCurrentStateType == TRAVERSE_ALL_NOT_ON_TERMINAL) ? mInputIndex : mInputIndex - 1; + *wordLength = outputIndex + 1; + if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) { return -1; } - const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == mInputIndex + 2) - : (mInputLength == mInputIndex + 1); + *word = mWord; + const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2) + : (mInputLength == inputIndex + 1); return CorrectionState::RankingAlgorithm::calculateFinalFreq( - mInputIndex, mOutputIndex, mMatchedCharCount, freq, sameLength, this); + inputIndex, outputIndex, mMatchedCharCount, freq, sameLength, this); } -void CorrectionState::initProcessState( - const int matchCount, const int inputIndex, const int outputIndex) { +void CorrectionState::initProcessState(const int matchCount, const int inputIndex, + const int outputIndex, const bool traverseAllNodes, const int diffs) { mMatchedCharCount = matchCount; mInputIndex = inputIndex; mOutputIndex = outputIndex; + mTraverseAllNodes = traverseAllNodes; + mDiffs = diffs; } -void CorrectionState::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex) { +void CorrectionState::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex, + bool *traverseAllNodes, int *diffs) { *matchedCount = mMatchedCharCount; *inputIndex = mInputIndex; *outputIndex = mOutputIndex; + *traverseAllNodes = mTraverseAllNodes; + *diffs = mDiffs; } void CorrectionState::charMatched() { @@ -95,6 +123,11 @@ int CorrectionState::getInputIndex() { return mInputIndex; } +// TODO: remove +bool CorrectionState::needsToTraverseAll() { + return mTraverseAllNodes; +} + void CorrectionState::incrementInputIndex() { ++mInputIndex; } @@ -103,6 +136,86 @@ void CorrectionState::incrementOutputIndex() { ++mOutputIndex; } +void CorrectionState::startTraverseAll() { + mTraverseAllNodes = true; +} + +bool CorrectionState::needsToPrune() const { + return (mOutputIndex - 1 >= (mTransposedPos >= 0 ? mInputLength - 1 : mMaxDepth) + || mDiffs > mMaxEditDistance); +} + +CorrectionState::CorrectionStateType CorrectionState::processCharAndCalcState( + const int32_t c, const bool isTerminal) { + mCurrentStateType = NOT_ON_TERMINAL; + // This has to be done for each virtual char (this forwards the "inputIndex" which + // is the index in the user-inputted chars, as read by proximity chars. + if (mExcessivePos == mOutputIndex && mInputIndex < mInputLength - 1) { + incrementInputIndex(); + } + + if (mTraverseAllNodes || needsToSkipCurrentNode(c)) { + mWord[mOutputIndex] = c; + if (needsToTraverseAll() && isTerminal) { + mCurrentStateType = TRAVERSE_ALL_ON_TERMINAL; + } else { + mCurrentStateType = TRAVERSE_ALL_NOT_ON_TERMINAL; + } + } else { + int inputIndexForProximity = mInputIndex; + + if (mTransposedPos >= 0) { + if (mInputIndex == mTransposedPos) { + ++inputIndexForProximity; + } + if (mInputIndex == (mTransposedPos + 1)) { + --inputIndexForProximity; + } + } + + int matchedProximityCharId = mProximityInfo->getMatchedProximityId( + inputIndexForProximity, c, this); + if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) { + mCurrentStateType = UNRELATED; + return mCurrentStateType; + } + mWord[mOutputIndex] = c; + // If inputIndex is greater than mInputLength, that means there is no + // proximity chars. So, we don't need to check proximity. + if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { + charMatched(); + } + + if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) { + incrementDiffs(); + } + + const bool isSameAsUserTypedLength = mInputLength + == getInputIndex() + 1 + || (mExcessivePos == mInputLength - 1 + && getInputIndex() == mInputLength - 2); + if (isSameAsUserTypedLength && isTerminal) { + mCurrentStateType = ON_TERMINAL; + } + // Start traversing all nodes after the index exceeds the user typed length + if (isSameAsUserTypedLength) { + startTraverseAll(); + } + + // Finally, we are ready to go to the next character, the next "virtual node". + // We should advance the input index. + // We do this in this branch of the 'if traverseAllNodes' because we are still matching + // characters to input; the other branch is not matching them but searching for + // completions, this is why it does not have to do it. + incrementInputIndex(); + } + + // Also, the next char is one "virtual node" depth more than this char. + incrementOutputIndex(); + + return mCurrentStateType; +} + CorrectionState::~CorrectionState() { } diff --git a/native/src/correction_state.h b/native/src/correction_state.h index cc3c3e669..a548bcb68 100644 --- a/native/src/correction_state.h +++ b/native/src/correction_state.h @@ -29,49 +29,76 @@ class CorrectionState { public: typedef enum { - ALLOW_ALL, + TRAVERSE_ALL_ON_TERMINAL, + TRAVERSE_ALL_NOT_ON_TERMINAL, UNRELATED, - RELATED + ON_TERMINAL, + NOT_ON_TERMINAL } CorrectionStateType; CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier); - void initCorrectionState(const ProximityInfo *pi, const int inputLength); + void initCorrectionState( + const ProximityInfo *pi, const int inputLength, const int maxWordLength); void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, const int spaceProximityPos, const int missingSpacePos); void checkState(); - void initProcessState(const int matchCount, const int inputIndex, const int outputIndex); - void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex); - void charMatched(); - void incrementInputIndex(); - void incrementOutputIndex(); + void initProcessState(const int matchCount, const int inputIndex, const int outputIndex, + const bool traverseAllNodes, const int diffs); + void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex, + bool *traverseAllNodes, int *diffs); int getOutputIndex(); int getInputIndex(); + bool needsToTraverseAll(); virtual ~CorrectionState(); - int getSkipPos() const { - return mSkipPos; - } - int getExcessivePos() const { - return mExcessivePos; - } - int getTransposedPos() const { - return mTransposedPos; - } int getSpaceProximityPos() const { return mSpaceProximityPos; } int getMissingSpacePos() const { return mMissingSpacePos; } - int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq); - int getFinalFreq(const unsigned short *word, const int freq); + int getSkipPos() const { + return mSkipPos; + } + + int getExcessivePos() const { + return mExcessivePos; + } + + int getTransposedPos() const { + return mTransposedPos; + } + + bool needsToPrune() const; + + int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq); + int getFinalFreq(const int freq, unsigned short **word, int* wordLength); + + CorrectionStateType processCharAndCalcState(const int32_t c, const bool isTerminal); + + int getDiffs() const { + return mDiffs; + } private: + void charMatched(); + void incrementInputIndex(); + void incrementOutputIndex(); + void startTraverseAll(); + + // TODO: remove + + void incrementDiffs() { + ++mDiffs; + } const int TYPED_LETTER_MULTIPLIER; const int FULL_WORD_MULTIPLIER; const ProximityInfo *mProximityInfo; + + int mMaxEditDistance; + int mMaxDepth; int mInputLength; int mSkipPos; int mExcessivePos; @@ -82,6 +109,12 @@ private: int mMatchedCharCount; int mInputIndex; int mOutputIndex; + int mDiffs; + bool mTraverseAllNodes; + CorrectionStateType mCurrentStateType; + unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; + + inline bool needsToSkipCurrentNode(const unsigned short c); class RankingAlgorithm { public: diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index b95da99a3..93d2b8418 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -181,14 +181,14 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, PROF_START(0); initSuggestions( proximityInfo, xcoordinates, ycoordinates, codes, codesSize, outWords, frequencies); - mCorrectionState->initCorrectionState(mProximityInfo, mInputLength); if (DEBUG_DICT) assert(codesSize == mInputLength); - const int MAX_DEPTH = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH); + const int maxDepth = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH); + mCorrectionState->initCorrectionState(mProximityInfo, mInputLength, maxDepth); PROF_END(0); PROF_START(1); - getSuggestionCandidates(-1, -1, -1, MAX_DEPTH); + getSuggestionCandidates(-1, -1, -1); PROF_END(1); PROF_START(2); @@ -198,7 +198,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest missing characters %d", i); } - getSuggestionCandidates(i, -1, -1, MAX_DEPTH); + getSuggestionCandidates(i, -1, -1); } } PROF_END(2); @@ -211,7 +211,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest excessive characters %d", i); } - getSuggestionCandidates(-1, i, -1, MAX_DEPTH); + getSuggestionCandidates(-1, i, -1); } } PROF_END(3); @@ -224,7 +224,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest transposed characters %d", i); } - getSuggestionCandidates(-1, -1, i, mInputLength - 1); + getSuggestionCandidates(-1, -1, i); } } PROF_END(4); @@ -272,7 +272,6 @@ void UnigramDictionary::initSuggestions(ProximityInfo *proximityInfo, const int mFrequencies = frequencies; mOutputChars = outWords; mInputLength = codesSize; - mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2; proximityInfo->setInputParams(codes, codesSize); mProximityInfo = proximityInfo; } @@ -342,9 +341,8 @@ static const char QUOTE = '\''; static const char SPACE = ' '; void UnigramDictionary::getSuggestionCandidates(const int skipPos, - const int excessivePos, const int transposedPos, const int maxDepth) { + const int excessivePos, const int transposedPos) { if (DEBUG_DICT) { - LOGI("getSuggestionCandidates %d", maxDepth); assert(transposedPos + 1 < mInputLength); assert(excessivePos < mInputLength); assert(missingPos < mInputLength); @@ -368,32 +366,26 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, while (depth >= 0) { if (mStackChildCount[depth] > 0) { --mStackChildCount[depth]; - bool traverseAllNodes = mStackTraverseAll[depth]; - int diffs = mStackDiffs[depth]; int siblingPos = mStackSiblingPos[depth]; int firstChildPos; mCorrectionState->initProcessState( - mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth]); + mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth], + mStackTraverseAll[depth], mStackDiffs[depth]); - // depth will never be greater than maxDepth because in that case, // needsToTraverseChildrenNodes should be false const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, - maxDepth, traverseAllNodes, diffs, - mCorrectionState, &childCount, - &firstChildPos, &traverseAllNodes, &diffs, - &siblingPos); + mCorrectionState, &childCount, &firstChildPos, &siblingPos); // Update next sibling pos mStackSiblingPos[depth] = siblingPos; if (needsToTraverseChildrenNodes) { // Goes to child node ++depth; mStackChildCount[depth] = childCount; - mStackTraverseAll[depth] = traverseAllNodes; - mStackDiffs[depth] = diffs; mStackSiblingPos[depth] = firstChildPos; mCorrectionState->getProcessState(&mStackMatchedCount[depth], - &mStackInputIndex[depth], &mStackOutputIndex[depth]); + &mStackInputIndex[depth], &mStackOutputIndex[depth], + &mStackTraverseAll[depth], &mStackDiffs[depth]); } } else { // Goes to parent sibling node @@ -437,12 +429,12 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c, return (c == QUOTE && userTypedChar != QUOTE) || skipPos == depth; } - -inline void UnigramDictionary::onTerminal( - unsigned short int* word, const int freq, CorrectionState *correctionState) { - const int finalFreq = correctionState->getFinalFreq(word, freq); +inline void UnigramDictionary::onTerminal(const int freq, CorrectionState *correctionState) { + int wordLength; + unsigned short* wordPointer; + const int finalFreq = correctionState->getFinalFreq(freq, &wordPointer, &wordLength); if (finalFreq >= 0) { - addWord(word, correctionState->getOutputIndex() + 1, finalFreq); + addWord(wordPointer, wordLength, finalFreq); } } @@ -657,20 +649,13 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs // there aren't any more nodes at this level, it merely returns the address of the first byte after // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any // given level, as output into newCount when traversing this level's parent. -inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int maxDepth, - const bool initialTraverseAllNodes, const int initialDiffs, - CorrectionState *correctionState, int *newCount, int *newChildrenPosition, - bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition) { - const int skipPos = correctionState->getSkipPos(); - const int excessivePos = correctionState->getExcessivePos(); - const int transposedPos = correctionState->getTransposedPos(); +inline bool UnigramDictionary::processCurrentNode(const int initialPos, + CorrectionState *correctionState, int *newCount, + int *newChildrenPosition, int *nextSiblingPosition) { if (DEBUG_DICT) { correctionState->checkState(); } int pos = initialPos; - int traverseAllNodes = initialTraverseAllNodes; - int diffs = initialDiffs; - const int initialInputIndex = correctionState->getInputIndex(); // Flags contain the following information: // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits: @@ -682,6 +667,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // - FLAG_HAS_BIGRAMS: whether this node has bigrams or not const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(DICT_ROOT, &pos); const bool hasMultipleChars = (0 != (FLAG_HAS_MULTIPLE_CHARS & flags)); + const bool isTerminalNode = (0 != (FLAG_IS_TERMINAL & flags)); + + bool needsToInvokeOnTerminal = false; // This gets only ONE character from the stream. Next there will be: // if FLAG_HAS_MULTIPLE CHARS: the other characters of the same node @@ -707,111 +695,21 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in const bool isLastChar = (NOT_A_CHARACTER == nextc); // If there are more chars in this nodes, then this virtual node is not a terminal. // If we are on the last char, this virtual node is a terminal if this node is. - const bool isTerminal = isLastChar && (0 != (FLAG_IS_TERMINAL & flags)); - // If there are more chars in this node, then this virtual node has children. - // If we are on the last char, this virtual node has children if this node has. - const bool hasChildren = (!isLastChar) || BinaryFormat::hasChildrenInFlags(flags); + const bool isTerminal = isLastChar && isTerminalNode; - // This has to be done for each virtual char (this forwards the "inputIndex" which - // is the index in the user-inputted chars, as read by proximity chars. - if (excessivePos == correctionState->getOutputIndex() - && correctionState->getInputIndex() < mInputLength - 1) { - correctionState->incrementInputIndex(); - } - if (traverseAllNodes || needsToSkipCurrentNode( - c, correctionState->getInputIndex(), skipPos, correctionState->getOutputIndex())) { - mWord[correctionState->getOutputIndex()] = c; - if (traverseAllNodes && isTerminal) { - // The frequency should be here, because we come here only if this is actually - // a terminal node, and we are on its last char. - const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, freq, mCorrectionState); - } - if (!hasChildren) { - // If we don't have children here, that means we finished processing all - // characters of this node (we are on the last virtual node), AND we are in - // traverseAllNodes mode, which means we are searching for *completions*. We - // should skip the frequency if we have a terminal, and report the position - // of the next sibling. We don't have to return other values because we are - // returning false, as in "don't traverse children". - if (isTerminal) pos = BinaryFormat::skipFrequency(flags, pos); - *nextSiblingPosition = - BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - return false; - } - } else { - int inputIndexForProximity = correctionState->getInputIndex(); - - if (transposedPos >= 0) { - if (correctionState->getInputIndex() == transposedPos) { - ++inputIndexForProximity; - } - if (correctionState->getInputIndex() == (transposedPos + 1)) { - --inputIndexForProximity; - } - } - - int matchedProximityCharId = mProximityInfo->getMatchedProximityId( - inputIndexForProximity, c, mCorrectionState); - if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) { - // We found that this is an unrelated character, so we should give up traversing - // this node and its children entirely. - // However we may not be on the last virtual node yet so we skip the remaining - // characters in this node, the frequency if it's there, read the next sibling - // position to output it, then return false. - // We don't have to output other values because we return false, as in - // "don't traverse children". - if (!isLastChar) { - pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos); - } - pos = BinaryFormat::skipFrequency(flags, pos); - *nextSiblingPosition = - BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - return false; - } - mWord[correctionState->getOutputIndex()] = c; - // If inputIndex is greater than mInputLength, that means there is no - // proximity chars. So, we don't need to check proximity. - if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { - correctionState->charMatched(); - } - const bool isSameAsUserTypedLength = mInputLength - == correctionState->getInputIndex() + 1 - || (excessivePos == mInputLength - 1 - && correctionState->getInputIndex() == mInputLength - 2); - if (isSameAsUserTypedLength && isTerminal) { - const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, freq, mCorrectionState); - } - // Start traversing all nodes after the index exceeds the user typed length - traverseAllNodes = isSameAsUserTypedLength; - diffs = diffs - + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0); - // Finally, we are ready to go to the next character, the next "virtual node". - // We should advance the input index. - // We do this in this branch of the 'if traverseAllNodes' because we are still matching - // characters to input; the other branch is not matching them but searching for - // completions, this is why it does not have to do it. - correctionState->incrementInputIndex(); - - // This character matched the typed character (enough to traverse the node at least) - // so we just evaluated it. Now we should evaluate this virtual node's children - that - // is, if it has any. If it has no children, we're done here - so we skip the end of - // the node, output the siblings position, and return false "don't traverse children". - // Note that !hasChildren implies isLastChar, so we know we don't have to skip any - // remaining char in this group for there can't be any. - if (!hasChildren) { - pos = BinaryFormat::skipFrequency(flags, pos); - *nextSiblingPosition = - BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - return false; - } - } - // Optimization: Prune out words that are too long compared to how much was typed. - if (isTerminal - && (correctionState->getOutputIndex() >= maxDepth || diffs > mMaxEditDistance)) { - // We are giving up parsing this node and its children. Skip the rest of the node, - // output the sibling position, and return that we don't want to traverse children. + CorrectionState::CorrectionStateType stateType = correctionState->processCharAndCalcState( + c, isTerminal); + if (stateType == CorrectionState::TRAVERSE_ALL_ON_TERMINAL + || stateType == CorrectionState::ON_TERMINAL) { + needsToInvokeOnTerminal = true; + } else if (stateType == CorrectionState::UNRELATED) { + // We found that this is an unrelated character, so we should give up traversing + // this node and its children entirely. + // However we may not be on the last virtual node yet so we skip the remaining + // characters in this node, the frequency if it's there, read the next sibling + // position to output it, then return false. + // We don't have to output other values because we return false, as in + // "don't traverse children". if (!isLastChar) { pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos); } @@ -820,8 +718,6 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); return false; } - // Also, the next char is one "virtual node" depth more than this char. - correctionState->incrementOutputIndex(); // Prepare for the next character. Promote the prefetched char to current char - the loop // will take care of prefetching the next. If we finally found our last char, nextc will @@ -829,16 +725,39 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in c = nextc; } while (NOT_A_CHARACTER != c); - // If inputIndex is greater than mInputLength, that means there are no proximity chars. - // Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength. - if (mInputLength <= initialInputIndex) { - traverseAllNodes = true; - } + if (isTerminalNode) { + if (needsToInvokeOnTerminal) { + // The frequency should be here, because we come here only if this is actually + // a terminal node, and we are on its last char. + const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); + onTerminal(freq, mCorrectionState); + } - // All the output values that are purely computation by this function are held in local - // variables. Output them to the caller. - *newTraverseAllNodes = traverseAllNodes; - *newDiffs = diffs; + // If there are more chars in this node, then this virtual node has children. + // If we are on the last char, this virtual node has children if this node has. + const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); + + // This character matched the typed character (enough to traverse the node at least) + // so we just evaluated it. Now we should evaluate this virtual node's children - that + // is, if it has any. If it has no children, we're done here - so we skip the end of + // the node, output the siblings position, and return false "don't traverse children". + // Note that !hasChildren implies isLastChar, so we know we don't have to skip any + // remaining char in this group for there can't be any. + if (!hasChildren) { + pos = BinaryFormat::skipFrequency(flags, pos); + *nextSiblingPosition = + BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); + return false; + } + + // Optimization: Prune out words that are too long compared to how much was typed. + if (correctionState->needsToPrune()) { + pos = BinaryFormat::skipFrequency(flags, pos); + *nextSiblingPosition = + BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); + return false; + } + } // Now we finished processing this node, and we want to traverse children. If there are no // children, we can't come here. diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h index cb86da41c..a45df24fb 100644 --- a/native/src/unigram_dictionary.h +++ b/native/src/unigram_dictionary.h @@ -87,21 +87,20 @@ private: const int *ycoordinates, const int *codes, const int codesSize, unsigned short *outWords, int *frequencies); void getSuggestionCandidates(const int skipPos, const int excessivePos, - const int transposedPos, const int maxDepth); + const int transposedPos); bool addWord(unsigned short *word, int length, int frequency); void getSplitTwoWordsSuggestion(const int inputLength, CorrectionState *correctionState); void getMissingSpaceWords( const int inputLength, const int missingSpacePos, CorrectionState *correctionState); void getMistypedSpaceWords( const int inputLength, const int spaceProximityPos, CorrectionState *correctionState); - void onTerminal(unsigned short int* word, const int freq, CorrectionState *correctionState); + void onTerminal(const int freq, CorrectionState *correctionState); bool needsToSkipCurrentNode(const unsigned short c, const int inputIndex, const int skipPos, const int depth); // Process a node by considering proximity, missing and excessive character - bool processCurrentNode(const int initialPos, const int maxDepth, - const bool initialTraverseAllNodes, const int initialDiffs, - CorrectionState *correctionState, int *newCount, int *newChildPosition, - bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition); + bool processCurrentNode(const int initialPos, + CorrectionState *correctionState, int *newCount, + int *newChildPosition, int *nextSiblingPosition); int getMostFrequentWordLike(const int startInputIndex, const int inputLength, unsigned short *word); int getMostFrequentWordLikeInner(const uint16_t* const inWord, const int length, @@ -134,7 +133,6 @@ private: int mInputLength; // MAX_WORD_LENGTH_INTERNAL must be bigger than MAX_WORD_LENGTH unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; - int mMaxEditDistance; int mStackMatchedCount[MAX_WORD_LENGTH_INTERNAL]; int mStackChildCount[MAX_WORD_LENGTH_INTERNAL];