diff --git a/native/src/correction_state.cpp b/native/src/correction_state.cpp index add6cf673..b2c77b00d 100644 --- a/native/src/correction_state.cpp +++ b/native/src/correction_state.cpp @@ -58,32 +58,49 @@ int CorrectionState::getFreqForSplitTwoWords(const int firstFreq, const int seco return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this); } -int CorrectionState::getFinalFreq(const int inputIndex, const int outputIndex, const int freq) { - const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2) - : (mInputLength == inputIndex + 1); - const int matchCount = mMatchedCharCount; +int CorrectionState::getFinalFreq(const unsigned short *word, const int freq) { + if (mProximityInfo->sameAsTyped(word, mOutputIndex + 1) || mOutputIndex < MIN_SUGGEST_DEPTH) { + return -1; + } + const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == mInputIndex + 2) + : (mInputLength == mInputIndex + 1); return CorrectionState::RankingAlgorithm::calculateFinalFreq( - inputIndex, outputIndex, matchCount, freq, sameLength, this); + mInputIndex, mOutputIndex, mMatchedCharCount, freq, sameLength, this); } -void CorrectionState::initDepth() { - mMatchedCharCount = 0; +void CorrectionState::initProcessState( + const int matchCount, const int inputIndex, const int outputIndex) { + mMatchedCharCount = matchCount; + mInputIndex = inputIndex; + mOutputIndex = outputIndex; +} + +void CorrectionState::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex) { + *matchedCount = mMatchedCharCount; + *inputIndex = mInputIndex; + *outputIndex = mOutputIndex; } void CorrectionState::charMatched() { ++mMatchedCharCount; } -void CorrectionState::goUpTree(const int matchCount) { - mMatchedCharCount = matchCount; +// TODO: remove +int CorrectionState::getOutputIndex() { + return mOutputIndex; } -void CorrectionState::slideTree(const int matchCount) { - mMatchedCharCount = matchCount; +// TODO: remove +int CorrectionState::getInputIndex() { + return mInputIndex; } -void CorrectionState::goDownTree(int *matchedCount) { - *matchedCount = mMatchedCharCount; +void CorrectionState::incrementInputIndex() { + ++mInputIndex; +} + +void CorrectionState::incrementOutputIndex() { + ++mOutputIndex; } CorrectionState::~CorrectionState() { diff --git a/native/src/correction_state.h b/native/src/correction_state.h index 7bbad5f5b..cc3c3e669 100644 --- a/native/src/correction_state.h +++ b/native/src/correction_state.h @@ -28,16 +28,25 @@ class ProximityInfo; class CorrectionState { public: + typedef enum { + ALLOW_ALL, + UNRELATED, + RELATED + } CorrectionStateType; + CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier); void initCorrectionState(const ProximityInfo *pi, const int inputLength); void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, const int spaceProximityPos, const int missingSpacePos); - void initDepth(); void checkState(); - void goUpTree(const int matchCount); - void slideTree(const int matchCount); - void goDownTree(int *matchedCount); + void initProcessState(const int matchCount, const int inputIndex, const int outputIndex); + void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex); void charMatched(); + void incrementInputIndex(); + void incrementOutputIndex(); + int getOutputIndex(); + int getInputIndex(); + virtual ~CorrectionState(); int getSkipPos() const { return mSkipPos; @@ -55,7 +64,7 @@ public: return mMissingSpacePos; } int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq); - int getFinalFreq(const int inputIndex, const int outputIndex, const int freq); + int getFinalFreq(const unsigned short *word, const int freq); private: @@ -71,6 +80,8 @@ private: int mMissingSpacePos; int mMatchedCharCount; + int mInputIndex; + int mOutputIndex; class RankingAlgorithm { public: diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index f5648d3df..9f8f04e50 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -363,27 +363,25 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, mStackSiblingPos[0] = rootPosition; mStackOutputIndex[0] = 0; mStackMatchedCount[0] = 0; - mCorrectionState->initDepth(); // Depth first search while (depth >= 0) { if (mStackChildCount[depth] > 0) { --mStackChildCount[depth]; bool traverseAllNodes = mStackTraverseAll[depth]; - int inputIndex = mStackInputIndex[depth]; int diffs = mStackDiffs[depth]; int siblingPos = mStackSiblingPos[depth]; - int outputIndex = mStackOutputIndex[depth]; int firstChildPos; - mCorrectionState->slideTree(mStackMatchedCount[depth]); + mCorrectionState->initProcessState( + mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth]); // depth will never be greater than maxDepth because in that case, // needsToTraverseChildrenNodes should be false - const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex, - maxDepth, traverseAllNodes, inputIndex, diffs, + const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, + maxDepth, traverseAllNodes, diffs, mCorrectionState, &childCount, - &firstChildPos, &traverseAllNodes, &inputIndex, &diffs, - &siblingPos, &outputIndex); + &firstChildPos, &traverseAllNodes, &diffs, + &siblingPos); // Update next sibling pos mStackSiblingPos[depth] = siblingPos; if (needsToTraverseChildrenNodes) { @@ -391,21 +389,15 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, ++depth; mStackChildCount[depth] = childCount; mStackTraverseAll[depth] = traverseAllNodes; - mStackInputIndex[depth] = inputIndex; mStackDiffs[depth] = diffs; mStackSiblingPos[depth] = firstChildPos; - mStackOutputIndex[depth] = outputIndex; - int matchedCount; - mCorrectionState->goDownTree(&matchedCount); - mStackMatchedCount[depth] = matchedCount; - } else { - mCorrectionState->slideTree(mStackMatchedCount[depth]); + mCorrectionState->getProcessState(&mStackMatchedCount[depth], + &mStackInputIndex[depth], &mStackOutputIndex[depth]); } } else { // Goes to parent sibling node --depth; - mCorrectionState->goUpTree(mStackMatchedCount[depth]); } } } @@ -446,13 +438,11 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c, } -inline void UnigramDictionary::onTerminal(unsigned short int* word, const int outputIndex, - const int inputIndex, const int freq, CorrectionState *correctionState) { - if (!mProximityInfo->sameAsTyped(word, outputIndex + 1) && outputIndex >= MIN_SUGGEST_DEPTH) { - const int finalFreq = correctionState->getFinalFreq(inputIndex, outputIndex, freq); - if (finalFreq >= 0) { - addWord(word, outputIndex + 1, finalFreq); - } +inline void UnigramDictionary::onTerminal( + unsigned short int* word, const int freq, CorrectionState *correctionState) { + const int finalFreq = correctionState->getFinalFreq(word, freq); + if (finalFreq >= 0) { + addWord(word, correctionState->getOutputIndex() + 1, finalFreq); } } @@ -667,12 +657,10 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs // there aren't any more nodes at this level, it merely returns the address of the first byte after // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any // given level, as output into newCount when traversing this level's parent. -inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialOutputPos, - const int maxDepth, const bool initialTraverseAllNodes, int inputIndex, - const int initialDiffs, +inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int maxDepth, + const bool initialTraverseAllNodes, const int initialDiffs, CorrectionState *correctionState, int *newCount, int *newChildrenPosition, - bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs, - int *nextSiblingPosition, int *newOutputIndex) { + bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition) { const int skipPos = correctionState->getSkipPos(); const int excessivePos = correctionState->getExcessivePos(); const int transposedPos = correctionState->getTransposedPos(); @@ -680,9 +668,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in correctionState->checkState(); } int pos = initialPos; - int internalOutputPos = initialOutputPos; int traverseAllNodes = initialTraverseAllNodes; int diffs = initialDiffs; + const int initialInputIndex = correctionState->getInputIndex(); // Flags contain the following information: // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits: @@ -726,16 +714,18 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // This has to be done for each virtual char (this forwards the "inputIndex" which // is the index in the user-inputted chars, as read by proximity chars. - if (excessivePos == internalOutputPos && inputIndex < mInputLength - 1) { - ++inputIndex; + if (excessivePos == correctionState->getOutputIndex() + && correctionState->getInputIndex() < mInputLength - 1) { + correctionState->incrementInputIndex(); } - if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, internalOutputPos)) { - mWord[internalOutputPos] = c; + if (traverseAllNodes || needsToSkipCurrentNode( + c, correctionState->getInputIndex(), skipPos, correctionState->getOutputIndex())) { + mWord[correctionState->getOutputIndex()] = c; if (traverseAllNodes && isTerminal) { // The frequency should be here, because we come here only if this is actually // a terminal node, and we are on its last char. const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState); + onTerminal(mWord, freq, mCorrectionState); } if (!hasChildren) { // If we don't have children here, that means we finished processing all @@ -750,11 +740,15 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in return false; } } else { - int inputIndexForProximity = inputIndex; + int inputIndexForProximity = correctionState->getInputIndex(); if (transposedPos >= 0) { - if (inputIndex == transposedPos) ++inputIndexForProximity; - if (inputIndex == (transposedPos + 1)) --inputIndexForProximity; + if (correctionState->getInputIndex() == transposedPos) { + ++inputIndexForProximity; + } + if (correctionState->getInputIndex() == (transposedPos + 1)) { + --inputIndexForProximity; + } } int matchedProximityCharId = mProximityInfo->getMatchedProximityId( @@ -775,18 +769,31 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); return false; } - mWord[internalOutputPos] = c; + mWord[correctionState->getOutputIndex()] = c; // If inputIndex is greater than mInputLength, that means there is no // proximity chars. So, we don't need to check proximity. if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { correctionState->charMatched(); } - const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1 - || (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2); + const bool isSameAsUserTypedLength = mInputLength + == correctionState->getInputIndex() + 1 + || (excessivePos == mInputLength - 1 + && correctionState->getInputIndex() == mInputLength - 2); if (isSameAsUserTypedLength && isTerminal) { const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState); + onTerminal(mWord, freq, mCorrectionState); } + // Start traversing all nodes after the index exceeds the user typed length + traverseAllNodes = isSameAsUserTypedLength; + diffs = diffs + + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0); + // Finally, we are ready to go to the next character, the next "virtual node". + // We should advance the input index. + // We do this in this branch of the 'if traverseAllNodes' because we are still matching + // characters to input; the other branch is not matching them but searching for + // completions, this is why it does not have to do it. + correctionState->incrementInputIndex(); + // This character matched the typed character (enough to traverse the node at least) // so we just evaluated it. Now we should evaluate this virtual node's children - that // is, if it has any. If it has no children, we're done here - so we skip the end of @@ -799,19 +806,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); return false; } - // Start traversing all nodes after the index exceeds the user typed length - traverseAllNodes = isSameAsUserTypedLength; - diffs = diffs - + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0); - // Finally, we are ready to go to the next character, the next "virtual node". - // We should advance the input index. - // We do this in this branch of the 'if traverseAllNodes' because we are still matching - // characters to input; the other branch is not matching them but searching for - // completions, this is why it does not have to do it. - ++inputIndex; } // Optimization: Prune out words that are too long compared to how much was typed. - if (internalOutputPos >= maxDepth || diffs > mMaxEditDistance) { + if (correctionState->getOutputIndex() >= maxDepth || diffs > mMaxEditDistance) { // We are giving up parsing this node and its children. Skip the rest of the node, // output the sibling position, and return that we don't want to traverse children. if (!isLastChar) { @@ -822,18 +819,18 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); return false; } + // Also, the next char is one "virtual node" depth more than this char. + correctionState->incrementOutputIndex(); // Prepare for the next character. Promote the prefetched char to current char - the loop // will take care of prefetching the next. If we finally found our last char, nextc will // contain NOT_A_CHARACTER. c = nextc; - // Also, the next char is one "virtual node" depth more than this char. - ++internalOutputPos; } while (NOT_A_CHARACTER != c); // If inputIndex is greater than mInputLength, that means there are no proximity chars. // Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength. - if (mInputLength <= *newInputIndex) { + if (mInputLength <= initialInputIndex) { traverseAllNodes = true; } @@ -841,8 +838,6 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // variables. Output them to the caller. *newTraverseAllNodes = traverseAllNodes; *newDiffs = diffs; - *newInputIndex = inputIndex; - *newOutputIndex = internalOutputPos; // Now we finished processing this node, and we want to traverse children. If there are no // children, we can't come here. diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h index c67eaf6e0..cb86da41c 100644 --- a/native/src/unigram_dictionary.h +++ b/native/src/unigram_dictionary.h @@ -94,18 +94,14 @@ private: const int inputLength, const int missingSpacePos, CorrectionState *correctionState); void getMistypedSpaceWords( const int inputLength, const int spaceProximityPos, CorrectionState *correctionState); - void onTerminal(unsigned short int* word, const int depth, - const int inputIndex, const int freq, - CorrectionState *correctionState); + void onTerminal(unsigned short int* word, const int freq, CorrectionState *correctionState); bool needsToSkipCurrentNode(const unsigned short c, const int inputIndex, const int skipPos, const int depth); // Process a node by considering proximity, missing and excessive character - bool processCurrentNode(const int initialPos, const int initialDepth, - const int maxDepth, const bool initialTraverseAllNodes, int inputIndex, - const int initialDiffs, + bool processCurrentNode(const int initialPos, const int maxDepth, + const bool initialTraverseAllNodes, const int initialDiffs, CorrectionState *correctionState, int *newCount, int *newChildPosition, - bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs, - int *nextSiblingPosition, int *nextOutputIndex); + bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition); int getMostFrequentWordLike(const int startInputIndex, const int inputLength, unsigned short *word); int getMostFrequentWordLikeInner(const uint16_t* const inWord, const int length,