diff --git a/native/src/correction_state.cpp b/native/src/correction_state.cpp index fba947ed4..add6cf673 100644 --- a/native/src/correction_state.cpp +++ b/native/src/correction_state.cpp @@ -58,10 +58,32 @@ int CorrectionState::getFreqForSplitTwoWords(const int firstFreq, const int seco return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this); } -int CorrectionState::getFinalFreq(const int inputIndex, const int depth, const int matchWeight, - const int freq, const bool sameLength) { - return CorrectionState::RankingAlgorithm::calculateFinalFreq(inputIndex, depth, matchWeight, - freq, sameLength, this); +int CorrectionState::getFinalFreq(const int inputIndex, const int outputIndex, const int freq) { + const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2) + : (mInputLength == inputIndex + 1); + const int matchCount = mMatchedCharCount; + return CorrectionState::RankingAlgorithm::calculateFinalFreq( + inputIndex, outputIndex, matchCount, freq, sameLength, this); +} + +void CorrectionState::initDepth() { + mMatchedCharCount = 0; +} + +void CorrectionState::charMatched() { + ++mMatchedCharCount; +} + +void CorrectionState::goUpTree(const int matchCount) { + mMatchedCharCount = matchCount; +} + +void CorrectionState::slideTree(const int matchCount) { + mMatchedCharCount = matchCount; +} + +void CorrectionState::goDownTree(int *matchedCount) { + *matchedCount = mMatchedCharCount; } CorrectionState::~CorrectionState() { @@ -117,7 +139,8 @@ inline static void multiplyRate(const int rate, int *freq) { // RankingAlgorithm // ////////////////////// -int CorrectionState::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int depth, +int CorrectionState::RankingAlgorithm::calculateFinalFreq( + const int inputIndex, const int outputIndex, const int matchCount, const int freq, const bool sameLength, const CorrectionState* correctionState) { const int skipPos = correctionState->getSkipPos(); @@ -156,10 +179,10 @@ int CorrectionState::RankingAlgorithm::calculateFinalFreq(const int inputIndex, } } int lengthFreq = typedLetterMultiplier; - multiplyIntCapped(powerIntCapped(typedLetterMultiplier, depth), &lengthFreq); - if (lengthFreq == matchWeight) { + multiplyIntCapped(powerIntCapped(typedLetterMultiplier, outputIndex), &lengthFreq); + if ((outputIndex + 1) == matchCount) { // Full exact match - if (depth > 1) { + if (outputIndex > 1) { if (DEBUG_DICT) { LOGI("Found full matched word."); } @@ -168,7 +191,8 @@ int CorrectionState::RankingAlgorithm::calculateFinalFreq(const int inputIndex, if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0) { finalFreq = capped255MultForFullMatchAccentsOrCapitalizationDifference(finalFreq); } - } else if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0 && depth > 0) { + } else if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0 + && outputIndex > 0) { // A word with proximity corrections if (DEBUG_DICT) { LOGI("Found one proximity correction."); @@ -177,7 +201,7 @@ int CorrectionState::RankingAlgorithm::calculateFinalFreq(const int inputIndex, multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); } if (DEBUG_DICT) { - LOGI("calc: %d, %d", depth, sameLength); + LOGI("calc: %d, %d", outputIndex, sameLength); } if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq); return finalFreq; diff --git a/native/src/correction_state.h b/native/src/correction_state.h index e03b2a17c..7bbad5f5b 100644 --- a/native/src/correction_state.h +++ b/native/src/correction_state.h @@ -32,7 +32,12 @@ public: void initCorrectionState(const ProximityInfo *pi, const int inputLength); void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, const int spaceProximityPos, const int missingSpacePos); + void initDepth(); void checkState(); + void goUpTree(const int matchCount); + void slideTree(const int matchCount); + void goDownTree(int *matchedCount); + void charMatched(); virtual ~CorrectionState(); int getSkipPos() const { return mSkipPos; @@ -50,13 +55,13 @@ public: return mMissingSpacePos; } int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq); - int getFinalFreq(const int inputIndex, const int depth, const int matchWeight, const int freq, - const bool sameLength); + int getFinalFreq(const int inputIndex, const int outputIndex, const int freq); private: const int TYPED_LETTER_MULTIPLIER; const int FULL_WORD_MULTIPLIER; + const ProximityInfo *mProximityInfo; int mInputLength; int mSkipPos; @@ -65,6 +70,8 @@ private: int mSpaceProximityPos; int mMissingSpacePos; + int mMatchedCharCount; + class RankingAlgorithm { public: static int calculateFinalFreq(const int inputIndex, const int depth, diff --git a/native/src/defines.h b/native/src/defines.h index bea83b2c5..5a5d3ee0c 100644 --- a/native/src/defines.h +++ b/native/src/defines.h @@ -176,9 +176,6 @@ static void prof_out(void) { #define MIN_USER_TYPED_LENGTH_FOR_MISSING_SPACE_SUGGESTION 3 #define MIN_USER_TYPED_LENGTH_FOR_EXCESSIVE_CHARACTER_SUGGESTION 3 -// The size of next letters frequency array. Zero will disable the feature. -#define NEXT_LETTERS_SIZE 0 - #define min(a,b) ((a)<(b)?(a):(b)) #endif // LATINIME_DEFINES_H diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index eb28538f1..f5648d3df 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -167,12 +167,6 @@ int UnigramDictionary::getSuggestions(ProximityInfo *proximityInfo, const int *x LOGI("%s %i", s, mFrequencies[j]); #endif } - LOGI("Next letters: "); - for (int k = 0; k < NEXT_LETTERS_SIZE; k++) { - if (mNextLettersFrequency[k] > 0) { - LOGI("%c = %d,", k, mNextLettersFrequency[k]); - } - } } PROF_END(20); PROF_CLOSE; @@ -194,7 +188,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, PROF_END(0); PROF_START(1); - getSuggestionCandidates(-1, -1, -1, mNextLettersFrequency, NEXT_LETTERS_SIZE, MAX_DEPTH); + getSuggestionCandidates(-1, -1, -1, MAX_DEPTH); PROF_END(1); PROF_START(2); @@ -204,7 +198,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest missing characters %d", i); } - getSuggestionCandidates(i, -1, -1, NULL, 0, MAX_DEPTH); + getSuggestionCandidates(i, -1, -1, MAX_DEPTH); } } PROF_END(2); @@ -217,7 +211,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest excessive characters %d", i); } - getSuggestionCandidates(-1, i, -1, NULL, 0, MAX_DEPTH); + getSuggestionCandidates(-1, i, -1, MAX_DEPTH); } } PROF_END(3); @@ -230,7 +224,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest transposed characters %d", i); } - getSuggestionCandidates(-1, -1, i, NULL, 0, mInputLength - 1); + getSuggestionCandidates(-1, -1, i, mInputLength - 1); } } PROF_END(4); @@ -348,8 +342,7 @@ static const char QUOTE = '\''; static const char SPACE = ' '; void UnigramDictionary::getSuggestionCandidates(const int skipPos, - const int excessivePos, const int transposedPos, int *nextLetters, - const int nextLettersSize, const int maxDepth) { + const int excessivePos, const int transposedPos, const int maxDepth) { if (DEBUG_DICT) { LOGI("getSuggestionCandidates %d", maxDepth); assert(transposedPos + 1 < mInputLength); @@ -365,29 +358,31 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, mStackChildCount[0] = childCount; mStackTraverseAll[0] = (mInputLength <= 0); - mStackMatchCount[0] = 0; mStackInputIndex[0] = 0; mStackDiffs[0] = 0; mStackSiblingPos[0] = rootPosition; mStackOutputIndex[0] = 0; + mStackMatchedCount[0] = 0; + mCorrectionState->initDepth(); // Depth first search while (depth >= 0) { if (mStackChildCount[depth] > 0) { --mStackChildCount[depth]; bool traverseAllNodes = mStackTraverseAll[depth]; - int matchCount = mStackMatchCount[depth]; int inputIndex = mStackInputIndex[depth]; int diffs = mStackDiffs[depth]; int siblingPos = mStackSiblingPos[depth]; int outputIndex = mStackOutputIndex[depth]; int firstChildPos; + mCorrectionState->slideTree(mStackMatchedCount[depth]); + // depth will never be greater than maxDepth because in that case, // needsToTraverseChildrenNodes should be false const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex, - maxDepth, traverseAllNodes, matchCount, inputIndex, diffs, - nextLetters, nextLettersSize, mCorrectionState, &childCount, - &firstChildPos, &traverseAllNodes, &matchCount, &inputIndex, &diffs, + maxDepth, traverseAllNodes, inputIndex, diffs, + mCorrectionState, &childCount, + &firstChildPos, &traverseAllNodes, &inputIndex, &diffs, &siblingPos, &outputIndex); // Update next sibling pos mStackSiblingPos[depth] = siblingPos; @@ -396,15 +391,21 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, ++depth; mStackChildCount[depth] = childCount; mStackTraverseAll[depth] = traverseAllNodes; - mStackMatchCount[depth] = matchCount; mStackInputIndex[depth] = inputIndex; mStackDiffs[depth] = diffs; mStackSiblingPos[depth] = firstChildPos; mStackOutputIndex[depth] = outputIndex; + + int matchedCount; + mCorrectionState->goDownTree(&matchedCount); + mStackMatchedCount[depth] = matchedCount; + } else { + mCorrectionState->slideTree(mStackMatchedCount[depth]); } } else { // Goes to parent sibling node --depth; + mCorrectionState->goUpTree(mStackMatchedCount[depth]); } } } @@ -445,24 +446,13 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c, } -inline void UnigramDictionary::onTerminal(unsigned short int* word, const int depth, - const uint8_t* const root, const uint8_t flags, const int pos, - const int inputIndex, const int matchCount, const int freq, const bool sameLength, - int* nextLetters, const int nextLettersSize, CorrectionState *correctionState) { - const int skipPos = correctionState->getSkipPos(); - - const bool isSameAsTyped = sameLength ? mProximityInfo->sameAsTyped(word, depth + 1) : false; - if (isSameAsTyped) return; - - if (depth >= MIN_SUGGEST_DEPTH) { - const int finalFreq = correctionState->getFinalFreq(inputIndex, depth, matchCount, - freq, sameLength); - if (!isSameAsTyped) - addWord(word, depth + 1, finalFreq); - } - - if (sameLength && depth >= mInputLength && skipPos < 0) { - registerNextLetter(word[mInputLength], nextLetters, nextLettersSize); +inline void UnigramDictionary::onTerminal(unsigned short int* word, const int outputIndex, + const int inputIndex, const int freq, CorrectionState *correctionState) { + if (!mProximityInfo->sameAsTyped(word, outputIndex + 1) && outputIndex >= MIN_SUGGEST_DEPTH) { + const int finalFreq = correctionState->getFinalFreq(inputIndex, outputIndex, freq); + if (finalFreq >= 0) { + addWord(word, outputIndex + 1, finalFreq); + } } } @@ -677,11 +667,11 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs // there aren't any more nodes at this level, it merely returns the address of the first byte after // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any // given level, as output into newCount when traversing this level's parent. -inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth, - const int maxDepth, const bool initialTraverseAllNodes, int matchCount, int inputIndex, - const int initialDiffs, int *nextLetters, const int nextLettersSize, +inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialOutputPos, + const int maxDepth, const bool initialTraverseAllNodes, int inputIndex, + const int initialDiffs, CorrectionState *correctionState, int *newCount, int *newChildrenPosition, - bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs, + bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs, int *nextSiblingPosition, int *newOutputIndex) { const int skipPos = correctionState->getSkipPos(); const int excessivePos = correctionState->getExcessivePos(); @@ -690,7 +680,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in correctionState->checkState(); } int pos = initialPos; - int depth = initialDepth; + int internalOutputPos = initialOutputPos; int traverseAllNodes = initialTraverseAllNodes; int diffs = initialDiffs; @@ -736,15 +726,16 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // This has to be done for each virtual char (this forwards the "inputIndex" which // is the index in the user-inputted chars, as read by proximity chars. - if (excessivePos == depth && inputIndex < mInputLength - 1) ++inputIndex; - if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, depth)) { - mWord[depth] = c; + if (excessivePos == internalOutputPos && inputIndex < mInputLength - 1) { + ++inputIndex; + } + if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, internalOutputPos)) { + mWord[internalOutputPos] = c; if (traverseAllNodes && isTerminal) { // The frequency should be here, because we come here only if this is actually // a terminal node, and we are on its last char. const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchCount, - freq, false, nextLetters, nextLettersSize, mCorrectionState); + onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState); } if (!hasChildren) { // If we don't have children here, that means we finished processing all @@ -784,18 +775,17 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); return false; } - mWord[depth] = c; + mWord[internalOutputPos] = c; // If inputIndex is greater than mInputLength, that means there is no // proximity chars. So, we don't need to check proximity. if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { - ++matchCount; + correctionState->charMatched(); } const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1 || (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2); if (isSameAsUserTypedLength && isTerminal) { const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchCount, - freq, true, nextLetters, nextLettersSize, mCorrectionState); + onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState); } // This character matched the typed character (enough to traverse the node at least) // so we just evaluated it. Now we should evaluate this virtual node's children - that @@ -821,7 +811,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in ++inputIndex; } // Optimization: Prune out words that are too long compared to how much was typed. - if (depth >= maxDepth || diffs > mMaxEditDistance) { + if (internalOutputPos >= maxDepth || diffs > mMaxEditDistance) { // We are giving up parsing this node and its children. Skip the rest of the node, // output the sibling position, and return that we don't want to traverse children. if (!isLastChar) { @@ -838,7 +828,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // contain NOT_A_CHARACTER. c = nextc; // Also, the next char is one "virtual node" depth more than this char. - ++depth; + ++internalOutputPos; } while (NOT_A_CHARACTER != c); // If inputIndex is greater than mInputLength, that means there are no proximity chars. @@ -850,10 +840,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // All the output values that are purely computation by this function are held in local // variables. Output them to the caller. *newTraverseAllNodes = traverseAllNodes; - *newMatchRate = matchCount; *newDiffs = diffs; *newInputIndex = inputIndex; - *newOutputIndex = depth; + *newOutputIndex = internalOutputPos; // Now we finished processing this node, and we want to traverse children. If there are no // children, we can't come here. diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h index f18ed6841..c67eaf6e0 100644 --- a/native/src/unigram_dictionary.h +++ b/native/src/unigram_dictionary.h @@ -87,8 +87,7 @@ private: const int *ycoordinates, const int *codes, const int codesSize, unsigned short *outWords, int *frequencies); void getSuggestionCandidates(const int skipPos, const int excessivePos, - const int transposedPos, int *nextLetters, const int nextLettersSize, - const int maxDepth); + const int transposedPos, const int maxDepth); bool addWord(unsigned short *word, int length, int frequency); void getSplitTwoWordsSuggestion(const int inputLength, CorrectionState *correctionState); void getMissingSpaceWords( @@ -96,17 +95,16 @@ private: void getMistypedSpaceWords( const int inputLength, const int spaceProximityPos, CorrectionState *correctionState); void onTerminal(unsigned short int* word, const int depth, - const uint8_t* const root, const uint8_t flags, const int pos, - const int inputIndex, const int matchWeight, const int freq, const bool sameLength, - int* nextLetters, const int nextLettersSize, CorrectionState *correctionState); + const int inputIndex, const int freq, + CorrectionState *correctionState); bool needsToSkipCurrentNode(const unsigned short c, const int inputIndex, const int skipPos, const int depth); // Process a node by considering proximity, missing and excessive character bool processCurrentNode(const int initialPos, const int initialDepth, - const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex, - const int initialDiffs, int *nextLetters, const int nextLettersSize, + const int maxDepth, const bool initialTraverseAllNodes, int inputIndex, + const int initialDiffs, CorrectionState *correctionState, int *newCount, int *newChildPosition, - bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs, + bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs, int *nextSiblingPosition, int *nextOutputIndex); int getMostFrequentWordLike(const int startInputIndex, const int inputLength, unsigned short *word); @@ -142,14 +140,13 @@ private: unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; int mMaxEditDistance; + int mStackMatchedCount[MAX_WORD_LENGTH_INTERNAL]; int mStackChildCount[MAX_WORD_LENGTH_INTERNAL]; bool mStackTraverseAll[MAX_WORD_LENGTH_INTERNAL]; - int mStackMatchCount[MAX_WORD_LENGTH_INTERNAL]; int mStackInputIndex[MAX_WORD_LENGTH_INTERNAL]; int mStackDiffs[MAX_WORD_LENGTH_INTERNAL]; int mStackSiblingPos[MAX_WORD_LENGTH_INTERNAL]; int mStackOutputIndex[MAX_WORD_LENGTH_INTERNAL]; - int mNextLettersFrequency[NEXT_LETTERS_SIZE]; }; } // namespace latinime