From 612c6e49c03dc49320a0bf141f51e45a8b969d43 Mon Sep 17 00:00:00 2001 From: satok Date: Mon, 1 Aug 2011 19:35:27 +0900 Subject: [PATCH] Move code related to ranking algorithm to the correction state Change-Id: I2d9e2db81cf6597ca4e88d7bc6737ab3b52b34b2 --- native/src/correction_state.cpp | 210 +++++++++++++++++++++++++- native/src/correction_state.h | 32 +++- native/src/unigram_dictionary.cpp | 235 +++++++----------------------- native/src/unigram_dictionary.h | 15 +- 4 files changed, 298 insertions(+), 194 deletions(-) diff --git a/native/src/correction_state.cpp b/native/src/correction_state.cpp index aa5efce40..fba947ed4 100644 --- a/native/src/correction_state.cpp +++ b/native/src/correction_state.cpp @@ -21,18 +21,26 @@ #define LOG_TAG "LatinIME: correction_state.cpp" #include "correction_state.h" +#include "proximity_info.h" namespace latinime { -CorrectionState::CorrectionState() { +CorrectionState::CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier) + : TYPED_LETTER_MULTIPLIER(typedLetterMultiplier), FULL_WORD_MULTIPLIER(fullWordMultiplier) { } -void CorrectionState::setCorrectionParams(const ProximityInfo *pi, const int inputLength, - const int skipPos, const int excessivePos, const int transposedPos) { +void CorrectionState::initCorrectionState(const ProximityInfo *pi, const int inputLength) { mProximityInfo = pi; + mInputLength = inputLength; +} + +void CorrectionState::setCorrectionParams(const int skipPos, const int excessivePos, + const int transposedPos, const int spaceProximityPos, const int missingSpacePos) { mSkipPos = skipPos; mExcessivePos = excessivePos; mTransposedPos = transposedPos; + mSpaceProximityPos = spaceProximityPos; + mMissingSpacePos = missingSpacePos; } void CorrectionState::checkState() { @@ -46,7 +54,203 @@ void CorrectionState::checkState() { } } +int CorrectionState::getFreqForSplitTwoWords(const int firstFreq, const int secondFreq) { + return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this); +} + +int CorrectionState::getFinalFreq(const int inputIndex, const int depth, const int matchWeight, + const int freq, const bool sameLength) { + return CorrectionState::RankingAlgorithm::calculateFinalFreq(inputIndex, depth, matchWeight, + freq, sameLength, this); +} + CorrectionState::~CorrectionState() { } +///////////////////////// +// static inline utils // +///////////////////////// + +static const int TWO_31ST_DIV_255 = S_INT_MAX / 255; +static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) { + return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX); +} + +static const int TWO_31ST_DIV_2 = S_INT_MAX / 2; +inline static void multiplyIntCapped(const int multiplier, int *base) { + const int temp = *base; + if (temp != S_INT_MAX) { + // Branch if multiplier == 2 for the optimization + if (multiplier == 2) { + *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX; + } else { + const int tempRetval = temp * multiplier; + *base = tempRetval >= temp ? tempRetval : S_INT_MAX; + } + } +} + +inline static int powerIntCapped(const int base, const int n) { + if (n == 0) return 1; + if (base == 2) { + return n < 31 ? 1 << n : S_INT_MAX; + } else { + int ret = base; + for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret); + return ret; + } +} + +inline static void multiplyRate(const int rate, int *freq) { + if (*freq != S_INT_MAX) { + if (*freq > 1000000) { + *freq /= 100; + multiplyIntCapped(rate, freq); + } else { + multiplyIntCapped(rate, freq); + *freq /= 100; + } + } +} + +////////////////////// +// RankingAlgorithm // +////////////////////// + +int CorrectionState::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int depth, + const int matchCount, const int freq, const bool sameLength, + const CorrectionState* correctionState) { + const int skipPos = correctionState->getSkipPos(); + const int excessivePos = correctionState->getExcessivePos(); + const int transposedPos = correctionState->getTransposedPos(); + const int inputLength = correctionState->mInputLength; + const int typedLetterMultiplier = correctionState->TYPED_LETTER_MULTIPLIER; + const int fullWordMultiplier = correctionState->FULL_WORD_MULTIPLIER; + const ProximityInfo *proximityInfo = correctionState->mProximityInfo; + const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); + + // TODO: Demote by edit distance + int finalFreq = freq * matchWeight; + if (skipPos >= 0) { + if (inputLength >= 2) { + const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE + * (10 * inputLength - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X) + / (10 * inputLength + - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X + 10); + if (DEBUG_DICT_FULL) { + LOGI("Demotion rate for missing character is %d.", demotionRate); + } + multiplyRate(demotionRate, &finalFreq); + } else { + finalFreq = 0; + } + } + if (transposedPos >= 0) multiplyRate( + WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE, &finalFreq); + if (excessivePos >= 0) { + multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_DEMOTION_RATE, &finalFreq); + if (!proximityInfo->existsAdjacentProximityChars(inputIndex)) { + // If an excessive character is not adjacent to the left char or the right char, + // we will demote this word. + multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_OUT_OF_PROXIMITY_DEMOTION_RATE, &finalFreq); + } + } + int lengthFreq = typedLetterMultiplier; + multiplyIntCapped(powerIntCapped(typedLetterMultiplier, depth), &lengthFreq); + if (lengthFreq == matchWeight) { + // Full exact match + if (depth > 1) { + if (DEBUG_DICT) { + LOGI("Found full matched word."); + } + multiplyRate(FULL_MATCHED_WORDS_PROMOTION_RATE, &finalFreq); + } + if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0) { + finalFreq = capped255MultForFullMatchAccentsOrCapitalizationDifference(finalFreq); + } + } else if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0 && depth > 0) { + // A word with proximity corrections + if (DEBUG_DICT) { + LOGI("Found one proximity correction."); + } + multiplyIntCapped(typedLetterMultiplier, &finalFreq); + multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); + } + if (DEBUG_DICT) { + LOGI("calc: %d, %d", depth, sameLength); + } + if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq); + return finalFreq; +} + +int CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords( + const int firstFreq, const int secondFreq, const CorrectionState* correctionState) { + const int spaceProximityPos = correctionState->mSpaceProximityPos; + const int missingSpacePos = correctionState->mMissingSpacePos; + if (DEBUG_DICT) { + int inputCount = 0; + if (spaceProximityPos >= 0) ++inputCount; + if (missingSpacePos >= 0) ++inputCount; + assert(inputCount <= 1); + } + const bool isSpaceProximity = spaceProximityPos >= 0; + const int inputLength = correctionState->mInputLength; + const int firstWordLength = isSpaceProximity ? spaceProximityPos : missingSpacePos; + const int secondWordLength = isSpaceProximity + ? (inputLength - spaceProximityPos - 1) + : (inputLength - missingSpacePos); + const int typedLetterMultiplier = correctionState->TYPED_LETTER_MULTIPLIER; + + if (firstWordLength == 0 || secondWordLength == 0) { + return 0; + } + const int firstDemotionRate = 100 - 100 / (firstWordLength + 1); + int tempFirstFreq = firstFreq; + multiplyRate(firstDemotionRate, &tempFirstFreq); + + const int secondDemotionRate = 100 - 100 / (secondWordLength + 1); + int tempSecondFreq = secondFreq; + multiplyRate(secondDemotionRate, &tempSecondFreq); + + const int totalLength = firstWordLength + secondWordLength; + + // Promote pairFreq with multiplying by 2, because the word length is the same as the typed + // length. + int totalFreq = tempFirstFreq + tempSecondFreq; + + // This is a workaround to try offsetting the not-enough-demotion which will be done in + // calcNormalizedScore in Utils.java. + // In calcNormalizedScore the score will be demoted by (1 - 1 / length) + // but we demoted only (1 - 1 / (length + 1)) so we will additionally adjust freq by + // (1 - 1 / length) / (1 - 1 / (length + 1)) = (1 - 1 / (length * length)) + const int normalizedScoreNotEnoughDemotionAdjustment = 100 - 100 / (totalLength * totalLength); + multiplyRate(normalizedScoreNotEnoughDemotionAdjustment, &totalFreq); + + // At this moment, totalFreq is calculated by the following formula: + // (firstFreq * (1 - 1 / (firstWordLength + 1)) + secondFreq * (1 - 1 / (secondWordLength + 1))) + // * (1 - 1 / totalLength) / (1 - 1 / (totalLength + 1)) + + multiplyIntCapped(powerIntCapped(typedLetterMultiplier, totalLength), &totalFreq); + + // This is another workaround to offset the demotion which will be done in + // calcNormalizedScore in Utils.java. + // In calcNormalizedScore the score will be demoted by (1 - 1 / length) so we have to promote + // the same amount because we already have adjusted the synthetic freq of this "missing or + // mistyped space" suggestion candidate above in this method. + const int normalizedScoreDemotionRateOffset = (100 + 100 / totalLength); + multiplyRate(normalizedScoreDemotionRateOffset, &totalFreq); + + if (isSpaceProximity) { + // A word pair with one space proximity correction + if (DEBUG_DICT) { + LOGI("Found a word pair with space proximity correction."); + } + multiplyIntCapped(typedLetterMultiplier, &totalFreq); + multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &totalFreq); + } + + multiplyRate(WORDS_WITH_MISSING_SPACE_CHARACTER_DEMOTION_RATE, &totalFreq); + return totalFreq; +} + } // namespace latinime diff --git a/native/src/correction_state.h b/native/src/correction_state.h index 5b7392590..e03b2a17c 100644 --- a/native/src/correction_state.h +++ b/native/src/correction_state.h @@ -26,10 +26,12 @@ namespace latinime { class ProximityInfo; class CorrectionState { + public: - CorrectionState(); - void setCorrectionParams(const ProximityInfo *pi, const int inputLength, const int skipPos, - const int excessivePos, const int transposedPos); + CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier); + void initCorrectionState(const ProximityInfo *pi, const int inputLength); + void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, + const int spaceProximityPos, const int missingSpacePos); void checkState(); virtual ~CorrectionState(); int getSkipPos() const { @@ -41,12 +43,36 @@ public: int getTransposedPos() const { return mTransposedPos; } + int getSpaceProximityPos() const { + return mSpaceProximityPos; + } + int getMissingSpacePos() const { + return mMissingSpacePos; + } + int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq); + int getFinalFreq(const int inputIndex, const int depth, const int matchWeight, const int freq, + const bool sameLength); + private: + + const int TYPED_LETTER_MULTIPLIER; + const int FULL_WORD_MULTIPLIER; const ProximityInfo *mProximityInfo; int mInputLength; int mSkipPos; int mExcessivePos; int mTransposedPos; + int mSpaceProximityPos; + int mMissingSpacePos; + + class RankingAlgorithm { + public: + static int calculateFinalFreq(const int inputIndex, const int depth, + const int matchCount, const int freq, const bool sameLength, + const CorrectionState* correctionState); + static int calcFreqForSplitTwoWords(const int firstFreq, const int secondFreq, + const CorrectionState* correctionState); + }; }; } // namespace latinime #endif // LATINIME_CORRECTION_INFO_H diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index f0bb384fb..eb28538f1 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -48,7 +48,7 @@ UnigramDictionary::UnigramDictionary(const uint8_t* const streamStart, int typed if (DEBUG_DICT) { LOGI("UnigramDictionary - constructor"); } - mCorrectionState = new CorrectionState(); + mCorrectionState = new CorrectionState(typedLetterMultiplier, fullWordMultiplier); } UnigramDictionary::~UnigramDictionary() { @@ -187,6 +187,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, PROF_START(0); initSuggestions( proximityInfo, xcoordinates, ycoordinates, codes, codesSize, outWords, frequencies); + mCorrectionState->initCorrectionState(mProximityInfo, mInputLength); if (DEBUG_DICT) assert(codesSize == mInputLength); const int MAX_DEPTH = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH); @@ -242,7 +243,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest missing space characters %d", i); } - getMissingSpaceWords(mInputLength, i); + getMissingSpaceWords(mInputLength, i, mCorrectionState); } } PROF_END(5); @@ -261,7 +262,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, i, x, y, proximityInfo->hasSpaceProximity(x, y)); } if (proximityInfo->hasSpaceProximity(x, y)) { - getMistypedSpaceWords(mInputLength, i); + getMistypedSpaceWords(mInputLength, i, mCorrectionState); } } } @@ -355,8 +356,8 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, assert(excessivePos < mInputLength); assert(missingPos < mInputLength); } - mCorrectionState->setCorrectionParams(mProximityInfo, mInputLength, skipPos, excessivePos, - transposedPos); + mCorrectionState->setCorrectionParams(skipPos, excessivePos, transposedPos, + -1 /* spaceProximityPos */, -1 /* missingSpacePos */); int rootPosition = ROOT_POS; // Get the number of children of root, then increment the position int childCount = Dictionary::getCount(DICT_ROOT, &rootPosition); @@ -364,7 +365,7 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, mStackChildCount[0] = childCount; mStackTraverseAll[0] = (mInputLength <= 0); - mStackNodeFreq[0] = 1; + mStackMatchCount[0] = 0; mStackInputIndex[0] = 0; mStackDiffs[0] = 0; mStackSiblingPos[0] = rootPosition; @@ -375,7 +376,7 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, if (mStackChildCount[depth] > 0) { --mStackChildCount[depth]; bool traverseAllNodes = mStackTraverseAll[depth]; - int matchWeight = mStackNodeFreq[depth]; + int matchCount = mStackMatchCount[depth]; int inputIndex = mStackInputIndex[depth]; int diffs = mStackDiffs[depth]; int siblingPos = mStackSiblingPos[depth]; @@ -384,9 +385,9 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, // depth will never be greater than maxDepth because in that case, // needsToTraverseChildrenNodes should be false const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex, - maxDepth, traverseAllNodes, matchWeight, inputIndex, diffs, + maxDepth, traverseAllNodes, matchCount, inputIndex, diffs, nextLetters, nextLettersSize, mCorrectionState, &childCount, - &firstChildPos, &traverseAllNodes, &matchWeight, &inputIndex, &diffs, + &firstChildPos, &traverseAllNodes, &matchCount, &inputIndex, &diffs, &siblingPos, &outputIndex); // Update next sibling pos mStackSiblingPos[depth] = siblingPos; @@ -395,7 +396,7 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, ++depth; mStackChildCount[depth] = childCount; mStackTraverseAll[depth] = traverseAllNodes; - mStackNodeFreq[depth] = matchWeight; + mStackMatchCount[depth] = matchCount; mStackInputIndex[depth] = inputIndex; mStackDiffs[depth] = diffs; mStackSiblingPos[depth] = firstChildPos; @@ -408,11 +409,6 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, } } -static const int TWO_31ST_DIV_255 = S_INT_MAX / 255; -static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) { - return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX); -} - static const int TWO_31ST_DIV_2 = S_INT_MAX / 2; inline static void multiplyIntCapped(const int multiplier, int *base) { const int temp = *base; @@ -427,153 +423,18 @@ inline static void multiplyIntCapped(const int multiplier, int *base) { } } -inline static int powerIntCapped(const int base, const int n) { - if (base == 2) { - return n < 31 ? 1 << n : S_INT_MAX; - } else { - int ret = base; - for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret); - return ret; - } +void UnigramDictionary::getMissingSpaceWords( + const int inputLength, const int missingSpacePos, CorrectionState *correctionState) { + correctionState->setCorrectionParams(-1 /* skipPos */, -1 /* excessivePos */, + -1 /* transposedPos */, -1 /* spaceProximityPos */, missingSpacePos); + getSplitTwoWordsSuggestion(inputLength, correctionState); } -inline static void multiplyRate(const int rate, int *freq) { - if (*freq != S_INT_MAX) { - if (*freq > 1000000) { - *freq /= 100; - multiplyIntCapped(rate, freq); - } else { - multiplyIntCapped(rate, freq); - *freq /= 100; - } - } -} - -inline static int calcFreqForSplitTwoWords( - const int typedLetterMultiplier, const int firstWordLength, const int secondWordLength, - const int firstFreq, const int secondFreq, const bool isSpaceProximity) { - if (firstWordLength == 0 || secondWordLength == 0) { - return 0; - } - const int firstDemotionRate = 100 - 100 / (firstWordLength + 1); - int tempFirstFreq = firstFreq; - multiplyRate(firstDemotionRate, &tempFirstFreq); - - const int secondDemotionRate = 100 - 100 / (secondWordLength + 1); - int tempSecondFreq = secondFreq; - multiplyRate(secondDemotionRate, &tempSecondFreq); - - const int totalLength = firstWordLength + secondWordLength; - - // Promote pairFreq with multiplying by 2, because the word length is the same as the typed - // length. - int totalFreq = tempFirstFreq + tempSecondFreq; - - // This is a workaround to try offsetting the not-enough-demotion which will be done in - // calcNormalizedScore in Utils.java. - // In calcNormalizedScore the score will be demoted by (1 - 1 / length) - // but we demoted only (1 - 1 / (length + 1)) so we will additionally adjust freq by - // (1 - 1 / length) / (1 - 1 / (length + 1)) = (1 - 1 / (length * length)) - const int normalizedScoreNotEnoughDemotionAdjustment = 100 - 100 / (totalLength * totalLength); - multiplyRate(normalizedScoreNotEnoughDemotionAdjustment, &totalFreq); - - // At this moment, totalFreq is calculated by the following formula: - // (firstFreq * (1 - 1 / (firstWordLength + 1)) + secondFreq * (1 - 1 / (secondWordLength + 1))) - // * (1 - 1 / totalLength) / (1 - 1 / (totalLength + 1)) - - multiplyIntCapped(powerIntCapped(typedLetterMultiplier, totalLength), &totalFreq); - - // This is another workaround to offset the demotion which will be done in - // calcNormalizedScore in Utils.java. - // In calcNormalizedScore the score will be demoted by (1 - 1 / length) so we have to promote - // the same amount because we already have adjusted the synthetic freq of this "missing or - // mistyped space" suggestion candidate above in this method. - const int normalizedScoreDemotionRateOffset = (100 + 100 / totalLength); - multiplyRate(normalizedScoreDemotionRateOffset, &totalFreq); - - if (isSpaceProximity) { - // A word pair with one space proximity correction - if (DEBUG_DICT) { - LOGI("Found a word pair with space proximity correction."); - } - multiplyIntCapped(typedLetterMultiplier, &totalFreq); - multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &totalFreq); - } - - multiplyRate(WORDS_WITH_MISSING_SPACE_CHARACTER_DEMOTION_RATE, &totalFreq); - return totalFreq; -} - -bool UnigramDictionary::getMissingSpaceWords(const int inputLength, const int missingSpacePos) { - return getSplitTwoWordsSuggestion( - inputLength, 0, missingSpacePos, missingSpacePos, inputLength - missingSpacePos, false); -} - -bool UnigramDictionary::getMistypedSpaceWords(const int inputLength, const int spaceProximityPos) { - return getSplitTwoWordsSuggestion( - inputLength, 0, spaceProximityPos, spaceProximityPos + 1, - inputLength - spaceProximityPos - 1, true); -} - -inline int UnigramDictionary::calculateFinalFreq(const int inputIndex, const int depth, - const int matchWeight, const int freq, const bool sameLength, - CorrectionState *correctionState) const { - const int skipPos = correctionState->getSkipPos(); - const int excessivePos = correctionState->getExcessivePos(); - const int transposedPos = correctionState->getTransposedPos(); - - // TODO: Demote by edit distance - int finalFreq = freq * matchWeight; - if (skipPos >= 0) { - if (mInputLength >= 2) { - const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE - * (10 * mInputLength - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X) - / (10 * mInputLength - - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X + 10); - if (DEBUG_DICT_FULL) { - LOGI("Demotion rate for missing character is %d.", demotionRate); - } - multiplyRate(demotionRate, &finalFreq); - } else { - finalFreq = 0; - } - } - if (transposedPos >= 0) multiplyRate( - WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE, &finalFreq); - if (excessivePos >= 0) { - multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_DEMOTION_RATE, &finalFreq); - if (!mProximityInfo->existsAdjacentProximityChars(inputIndex)) { - // If an excessive character is not adjacent to the left char or the right char, - // we will demote this word. - multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_OUT_OF_PROXIMITY_DEMOTION_RATE, &finalFreq); - } - } - int lengthFreq = TYPED_LETTER_MULTIPLIER; - multiplyIntCapped(powerIntCapped(TYPED_LETTER_MULTIPLIER, depth), &lengthFreq); - if (lengthFreq == matchWeight) { - // Full exact match - if (depth > 1) { - if (DEBUG_DICT) { - LOGI("Found full matched word."); - } - multiplyRate(FULL_MATCHED_WORDS_PROMOTION_RATE, &finalFreq); - } - if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0) { - finalFreq = capped255MultForFullMatchAccentsOrCapitalizationDifference(finalFreq); - } - } else if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0 && depth > 0) { - // A word with proximity corrections - if (DEBUG_DICT) { - LOGI("Found one proximity correction."); - } - multiplyIntCapped(TYPED_LETTER_MULTIPLIER, &finalFreq); - multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); - } - if (DEBUG_DICT) { - LOGI("calc: %d, %d", depth, sameLength); - } - if (sameLength) multiplyIntCapped(FULL_WORD_MULTIPLIER, &finalFreq); - return finalFreq; +void UnigramDictionary::getMistypedSpaceWords( + const int inputLength, const int spaceProximityPos, CorrectionState *correctionState) { + correctionState->setCorrectionParams(-1 /* skipPos */, -1 /* excessivePos */, + -1 /* transposedPos */, spaceProximityPos, -1 /* missingSpacePos */); + getSplitTwoWordsSuggestion(inputLength, correctionState); } inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c, @@ -586,7 +447,7 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c, inline void UnigramDictionary::onTerminal(unsigned short int* word, const int depth, const uint8_t* const root, const uint8_t flags, const int pos, - const int inputIndex, const int matchWeight, const int freq, const bool sameLength, + const int inputIndex, const int matchCount, const int freq, const bool sameLength, int* nextLetters, const int nextLettersSize, CorrectionState *correctionState) { const int skipPos = correctionState->getSkipPos(); @@ -594,8 +455,8 @@ inline void UnigramDictionary::onTerminal(unsigned short int* word, const int de if (isSameAsTyped) return; if (depth >= MIN_SUGGEST_DEPTH) { - const int finalFreq = calculateFinalFreq(inputIndex, depth, matchWeight, - freq, sameLength, correctionState); + const int finalFreq = correctionState->getFinalFreq(inputIndex, depth, matchCount, + freq, sameLength); if (!isSameAsTyped) addWord(word, depth + 1, finalFreq); } @@ -605,13 +466,29 @@ inline void UnigramDictionary::onTerminal(unsigned short int* word, const int de } } -bool UnigramDictionary::getSplitTwoWordsSuggestion(const int inputLength, - const int firstWordStartPos, const int firstWordLength, const int secondWordStartPos, - const int secondWordLength, const bool isSpaceProximity) { - if (inputLength >= MAX_WORD_LENGTH) return false; +void UnigramDictionary::getSplitTwoWordsSuggestion( + const int inputLength, CorrectionState* correctionState) { + const int spaceProximityPos = correctionState->getSpaceProximityPos(); + const int missingSpacePos = correctionState->getMissingSpacePos(); + if (DEBUG_DICT) { + int inputCount = 0; + if (spaceProximityPos >= 0) ++inputCount; + if (missingSpacePos >= 0) ++inputCount; + assert(inputCount <= 1); + } + const bool isSpaceProximity = spaceProximityPos >= 0; + const int firstWordStartPos = 0; + const int secondWordStartPos = isSpaceProximity ? (spaceProximityPos + 1) : missingSpacePos; + const int firstWordLength = isSpaceProximity ? spaceProximityPos : missingSpacePos; + const int secondWordLength = isSpaceProximity + ? (inputLength - spaceProximityPos - 1) + : (inputLength - missingSpacePos); + + if (inputLength >= MAX_WORD_LENGTH) return; if (0 >= firstWordLength || 0 >= secondWordLength || firstWordStartPos >= secondWordStartPos || firstWordStartPos < 0 || secondWordStartPos + secondWordLength > inputLength) - return false; + return; + const int newWordLength = firstWordLength + secondWordLength + 1; // Allocating variable length array on stack unsigned short word[newWordLength]; @@ -619,7 +496,7 @@ bool UnigramDictionary::getSplitTwoWordsSuggestion(const int inputLength, if (DEBUG_DICT) { LOGI("First freq: %d", firstFreq); } - if (firstFreq <= 0) return false; + if (firstFreq <= 0) return; for (int i = 0; i < firstWordLength; ++i) { word[i] = mWord[i]; @@ -629,21 +506,19 @@ bool UnigramDictionary::getSplitTwoWordsSuggestion(const int inputLength, if (DEBUG_DICT) { LOGI("Second freq: %d", secondFreq); } - if (secondFreq <= 0) return false; + if (secondFreq <= 0) return; word[firstWordLength] = SPACE; for (int i = (firstWordLength + 1); i < newWordLength; ++i) { word[i] = mWord[i - firstWordLength - 1]; } - int pairFreq = calcFreqForSplitTwoWords(TYPED_LETTER_MULTIPLIER, firstWordLength, - secondWordLength, firstFreq, secondFreq, isSpaceProximity); + const int pairFreq = mCorrectionState->getFreqForSplitTwoWords(firstFreq, secondFreq); if (DEBUG_DICT) { - LOGI("Split two words: %d, %d, %d, %d, %d", firstFreq, secondFreq, pairFreq, inputLength, - TYPED_LETTER_MULTIPLIER); + LOGI("Split two words: %d, %d, %d, %d", firstFreq, secondFreq, pairFreq, inputLength); } addWord(word, newWordLength, pairFreq); - return true; + return; } // Wrapper for getMostFrequentWordLikeInner, which matches it to the previous @@ -803,7 +678,7 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any // given level, as output into newCount when traversing this level's parent. inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth, - const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex, + const int maxDepth, const bool initialTraverseAllNodes, int matchCount, int inputIndex, const int initialDiffs, int *nextLetters, const int nextLettersSize, CorrectionState *correctionState, int *newCount, int *newChildrenPosition, bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs, @@ -868,7 +743,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // The frequency should be here, because we come here only if this is actually // a terminal node, and we are on its last char. const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, + onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchCount, freq, false, nextLetters, nextLettersSize, mCorrectionState); } if (!hasChildren) { @@ -913,13 +788,13 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // If inputIndex is greater than mInputLength, that means there is no // proximity chars. So, we don't need to check proximity. if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { - multiplyIntCapped(TYPED_LETTER_MULTIPLIER, &matchWeight); + ++matchCount; } const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1 || (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2); if (isSameAsUserTypedLength && isTerminal) { const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, + onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchCount, freq, true, nextLetters, nextLettersSize, mCorrectionState); } // This character matched the typed character (enough to traverse the node at least) @@ -975,7 +850,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // All the output values that are purely computation by this function are held in local // variables. Output them to the caller. *newTraverseAllNodes = traverseAllNodes; - *newMatchRate = matchWeight; + *newMatchRate = matchCount; *newDiffs = diffs; *newInputIndex = inputIndex; *newOutputIndex = depth; diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h index 41e381860..f18ed6841 100644 --- a/native/src/unigram_dictionary.h +++ b/native/src/unigram_dictionary.h @@ -74,6 +74,7 @@ public: virtual ~UnigramDictionary(); private: + void getWordSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, const int *ycoordinates, const int *codes, const int codesSize, unsigned short *outWords, int *frequencies); @@ -89,13 +90,11 @@ private: const int transposedPos, int *nextLetters, const int nextLettersSize, const int maxDepth); bool addWord(unsigned short *word, int length, int frequency); - bool getSplitTwoWordsSuggestion(const int inputLength, - const int firstWordStartPos, const int firstWordLength, - const int secondWordStartPos, const int secondWordLength, const bool isSpaceProximity); - bool getMissingSpaceWords(const int inputLength, const int missingSpacePos); - bool getMistypedSpaceWords(const int inputLength, const int spaceProximityPos); - int calculateFinalFreq(const int inputIndex, const int depth, const int snr, - const int freq, const bool sameLength, CorrectionState *correctionState) const; + void getSplitTwoWordsSuggestion(const int inputLength, CorrectionState *correctionState); + void getMissingSpaceWords( + const int inputLength, const int missingSpacePos, CorrectionState *correctionState); + void getMistypedSpaceWords( + const int inputLength, const int spaceProximityPos, CorrectionState *correctionState); void onTerminal(unsigned short int* word, const int depth, const uint8_t* const root, const uint8_t flags, const int pos, const int inputIndex, const int matchWeight, const int freq, const bool sameLength, @@ -145,7 +144,7 @@ private: int mStackChildCount[MAX_WORD_LENGTH_INTERNAL]; bool mStackTraverseAll[MAX_WORD_LENGTH_INTERNAL]; - int mStackNodeFreq[MAX_WORD_LENGTH_INTERNAL]; + int mStackMatchCount[MAX_WORD_LENGTH_INTERNAL]; int mStackInputIndex[MAX_WORD_LENGTH_INTERNAL]; int mStackDiffs[MAX_WORD_LENGTH_INTERNAL]; int mStackSiblingPos[MAX_WORD_LENGTH_INTERNAL];