diff --git a/native/src/correction.cpp b/native/src/correction.cpp index a931a61fb..654d4715f 100644 --- a/native/src/correction.cpp +++ b/native/src/correction.cpp @@ -49,12 +49,11 @@ void Correction::initCorrection(const ProximityInfo *pi, const int inputLength, mInputLength = inputLength; mMaxDepth = maxDepth; mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2; - mSkippedOutputIndex = -1; } void Correction::initCorrectionState( const int rootPos, const int childCount, const bool traverseAll) { - mCorrectionStates[0].init(rootPos, childCount, traverseAll); + latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll); } void Correction::setCorrectionParams(const int skipPos, const int excessivePos, @@ -88,6 +87,12 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) { return -1; } + + // TODO: Remove this + if (mSkipPos >= 0 && mSkippedCount <= 0) { + return -1; + } + *word = mWord; const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2) : (mInputLength == inputIndex + 1); @@ -103,8 +108,11 @@ bool Correction::initProcessState(const int outputIndex) { --(mCorrectionStates[outputIndex].mChildCount); mMatchedCharCount = mCorrectionStates[outputIndex].mMatchedCount; mInputIndex = mCorrectionStates[outputIndex].mInputIndex; - mTraverseAllNodes = mCorrectionStates[outputIndex].mTraverseAll; + mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes; mDiffs = mCorrectionStates[outputIndex].mDiffs; + mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount; + mSkipping = false; + mMatching = false; return true; } @@ -131,8 +139,8 @@ int Correction::getInputIndex() { } // TODO: remove -bool Correction::needsToTraverseAll() { - return mTraverseAllNodes; +bool Correction::needsToTraverseAllNodes() { + return mNeedsToTraverseAllNodes; } void Correction::incrementInputIndex() { @@ -146,12 +154,15 @@ void Correction::incrementOutputIndex() { mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; mCorrectionStates[mOutputIndex].mMatchedCount = mMatchedCharCount; mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; - mCorrectionStates[mOutputIndex].mTraverseAll = mTraverseAllNodes; + mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; mCorrectionStates[mOutputIndex].mDiffs = mDiffs; + mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; + mCorrectionStates[mOutputIndex].mSkipping = mSkipping; + mCorrectionStates[mOutputIndex].mMatching = mMatching; } -void Correction::startTraverseAll() { - mTraverseAllNodes = true; +void Correction::startToTraverseAllNodes() { + mNeedsToTraverseAllNodes = true; } bool Correction::needsToPrune() const { @@ -162,7 +173,7 @@ bool Correction::needsToPrune() const { Correction::CorrectionType Correction::processSkipChar( const int32_t c, const bool isTerminal) { mWord[mOutputIndex] = c; - if (needsToTraverseAll() && isTerminal) { + if (needsToTraverseAllNodes() && isTerminal) { mTerminalInputIndex = mInputIndex; mTerminalOutputIndex = mOutputIndex; incrementOutputIndex(); @@ -185,9 +196,10 @@ Correction::CorrectionType Correction::processCharAndCalcState( bool skip = false; if (mSkipPos >= 0) { skip = mSkipPos == mOutputIndex; + mSkipping = true; } - if (mTraverseAllNodes || isQuote(c)) { + if (mNeedsToTraverseAllNodes || isQuote(c)) { return processSkipChar(c, isTerminal); } else { int inputIndexForProximity = mInputIndex; @@ -210,25 +222,23 @@ Correction::CorrectionType Correction::processCharAndCalcState( if (unrelated) { if (skip) { // Skip this letter and continue deeper - mSkippedOutputIndex = mOutputIndex; + ++mSkippedCount; return processSkipChar(c, isTerminal); } else { return UNRELATED; } } - // No need to skip. Finish traversing and increment skipPos. - // TODO: Remove this? + // TODO: remove after allowing combination errors if (skip) { - mWord[mOutputIndex] = c; - incrementOutputIndex(); - return TRAVERSE_ALL_NOT_ON_TERMINAL; + return UNRELATED; } mWord[mOutputIndex] = c; // If inputIndex is greater than mInputLength, that means there is no // proximity chars. So, we don't need to check proximity. if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { + mMatching = true; charMatched(); } @@ -247,7 +257,7 @@ Correction::CorrectionType Correction::processCharAndCalcState( } // Start traversing all nodes after the index exceeds the user typed length if (isSameAsUserTypedLength) { - startTraverseAll(); + startToTraverseAllNodes(); } // Finally, we are ready to go to the next character, the next "virtual node". @@ -317,6 +327,7 @@ inline static void multiplyRate(const int rate, int *freq) { // RankingAlgorithm // ////////////////////// +/* static */ int Correction::RankingAlgorithm::calculateFinalFreq( const int inputIndex, const int outputIndex, const int matchCount, const int freq, const bool sameLength, @@ -329,6 +340,8 @@ int Correction::RankingAlgorithm::calculateFinalFreq( const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER; const ProximityInfo *proximityInfo = correction->mProximityInfo; const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); + const unsigned short* word = correction->mWord; + const int skippedCount = correction->mSkippedCount; // TODO: Demote by edit distance int finalFreq = freq * matchWeight; @@ -382,9 +395,30 @@ int Correction::RankingAlgorithm::calculateFinalFreq( LOGI("calc: %d, %d", outputIndex, sameLength); } if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq); + + // TODO: check excessive count and transposed count + /* + If the last character of the user input word is the same as the next character + of the output word, and also all of characters of the user input are matched + to the output word, we'll promote that word a bit because + that word can be considered the combination of skipped and matched characters. + This means that the 'sm' pattern wins over the 'ma' pattern. + e.g.) + shel -> shell [mmmma] or [mmmsm] + hel -> hello [mmmaa] or [mmsma] + m ... matching + s ... skipping + a ... traversing all + */ + if (matchCount == inputLength && matchCount >= 2 && skippedCount == 0 + && word[matchCount] == word[matchCount - 1]) { + multiplyRate(WORDS_WITH_MATCH_SKIP_PROMOTION_RATE, &finalFreq); + } + return finalFreq; } +/* static */ int Correction::RankingAlgorithm::calcFreqForSplitTwoWords( const int firstFreq, const int secondFreq, const Correction* correction) { const int spaceProximityPos = correction->mSpaceProximityPos; diff --git a/native/src/correction.h b/native/src/correction.h index 590d62f62..62fe38696 100644 --- a/native/src/correction.h +++ b/native/src/correction.h @@ -52,7 +52,6 @@ public: bool *traverseAllNodes, int *diffs); int getOutputIndex(); int getInputIndex(); - bool needsToTraverseAll(); virtual ~Correction(); int getSpaceProximityPos() const { @@ -101,45 +100,46 @@ public: return mCorrectionStates[index].mParentIndex; } private: - void charMatched(); - void incrementInputIndex(); - void incrementOutputIndex(); - void startTraverseAll(); + inline void charMatched(); + inline void incrementInputIndex(); + inline void incrementOutputIndex(); + inline bool needsToTraverseAllNodes(); + inline void startToTraverseAllNodes(); + inline bool isQuote(const unsigned short c); + inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal); // TODO: remove - - void incrementDiffs() { + inline void incrementDiffs() { ++mDiffs; } const int TYPED_LETTER_MULTIPLIER; const int FULL_WORD_MULTIPLIER; - const ProximityInfo *mProximityInfo; int mMaxEditDistance; int mMaxDepth; int mInputLength; int mSkipPos; - int mSkippedOutputIndex; int mExcessivePos; int mTransposedPos; int mSpaceProximityPos; int mMissingSpacePos; - - int mMatchedCharCount; - int mInputIndex; - int mOutputIndex; int mTerminalInputIndex; int mTerminalOutputIndex; - int mDiffs; - bool mTraverseAllNodes; unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL]; - inline bool isQuote(const unsigned short c); - inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal); + // The following member variables are being used as cache values of the correction state. + int mOutputIndex; + int mInputIndex; + int mDiffs; + int mMatchedCharCount; + int mSkippedCount; + bool mNeedsToTraverseAllNodes; + bool mMatching; + bool mSkipping; class RankingAlgorithm { public: diff --git a/native/src/correction_state.h b/native/src/correction_state.h index 1fe02b853..731222696 100644 --- a/native/src/correction_state.h +++ b/native/src/correction_state.h @@ -23,32 +23,33 @@ namespace latinime { -class CorrectionState { -public: +struct CorrectionState { int mParentIndex; - int mMatchedCount; - int mChildCount; - int mInputIndex; - int mDiffs; int mSiblingPos; - bool mTraverseAll; + uint16_t mChildCount; + uint8_t mInputIndex; + uint8_t mDiffs; + uint8_t mMatchedCount; + uint8_t mSkippedCount; + bool mMatching; + bool mSkipping; + bool mNeedsToTraverseAllNodes; - inline void init(const int rootPos, const int childCount, const bool traverseAll) { - set(-1, 0, childCount, 0, 0, rootPos, traverseAll); - } - -private: - inline void set(const int parentIndex, const int matchedCount, const int childCount, - const int inputIndex, const int diffs, const int siblingPos, - const bool traverseAll) { - mParentIndex = parentIndex; - mMatchedCount = matchedCount; - mChildCount = childCount; - mInputIndex = inputIndex; - mDiffs = diffs; - mSiblingPos = siblingPos; - mTraverseAll = traverseAll; - } }; + +inline static void initCorrectionState(CorrectionState *state, const int rootPos, + const uint16_t childCount, const bool traverseAll) { + state->mParentIndex = -1; + state->mChildCount = childCount; + state->mInputIndex = 0; + state->mDiffs = 0; + state->mSiblingPos = rootPos; + state->mMatchedCount = 0; + state->mSkippedCount = 0; + state->mMatching = false; + state->mSkipping = false; + state->mNeedsToTraverseAllNodes = traverseAll; +} + } // namespace latinime #endif // LATINIME_CORRECTION_STATE_H diff --git a/native/src/defines.h b/native/src/defines.h index 5a5d3ee0c..c1838d341 100644 --- a/native/src/defines.h +++ b/native/src/defines.h @@ -160,6 +160,7 @@ static void prof_out(void) { #define WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE 60 #define FULL_MATCHED_WORDS_PROMOTION_RATE 120 #define WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE 90 +#define WORDS_WITH_MATCH_SKIP_PROMOTION_RATE 105 // This should be greater than or equal to MAX_WORD_LENGTH defined in BinaryDictionary.java // This is only used for the size of array. Not to be used in c functions. diff --git a/native/src/proximity_info.h b/native/src/proximity_info.h index 5034c3b89..d9ed46f5b 100644 --- a/native/src/proximity_info.h +++ b/native/src/proximity_info.h @@ -46,6 +46,7 @@ public: ProximityType getMatchedProximityId( const int index, const unsigned short c, const bool checkProximityChars) const; bool sameAsTyped(const unsigned short *word, int length) const; + private: int getStartIndexFromCoordinates(const int x, const int y) const; const int MAX_PROXIMITY_CHARS_SIZE;