diff --git a/native/src/correction.cpp b/native/src/correction.cpp index f8f73ddf5..a4090a966 100644 --- a/native/src/correction.cpp +++ b/native/src/correction.cpp @@ -21,6 +21,7 @@ #define LOG_TAG "LatinIME: correction.cpp" #include "correction.h" +#include "dictionary.h" #include "proximity_info.h" namespace latinime { @@ -93,16 +94,11 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen return -1; } - // TODO: Remove this - if (mSkipPos >= 0 && mSkippedCount <= 0) { - return -1; - } - *word = mWord; const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2) : (mInputLength == inputIndex + 1); return Correction::RankingAlgorithm::calculateFinalFreq( - inputIndex, outputIndex, freq, sameLength, this); + inputIndex, outputIndex, freq, sameLength, mEditDistanceTable, this); } bool Correction::initProcessState(const int outputIndex) { @@ -117,6 +113,7 @@ bool Correction::initProcessState(const int outputIndex) { mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount; mSkipPos = mCorrectionStates[outputIndex].mSkipPos; mSkipping = false; + mProximityMatching = false; mMatching = false; return true; } @@ -160,6 +157,7 @@ void Correction::incrementOutputIndex() { mCorrectionStates[mOutputIndex].mSkipping = mSkipping; mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos; mCorrectionStates[mOutputIndex].mMatching = mMatching; + mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching; } void Correction::startToTraverseAllNodes() { @@ -207,6 +205,20 @@ Correction::CorrectionType Correction::processCharAndCalcState( } if (mNeedsToTraverseAllNodes || isQuote(c)) { + const bool checkProximityChars = + !(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0); + // Note: This logic tries saving cases like contrst --> contrast -- "a" is one of + // proximity chars of "s", but it should rather be handled as a skipped char. 
+ if (checkProximityChars + && mInputIndex > 0 + && mCorrectionStates[mOutputIndex].mProximityMatching + && mCorrectionStates[mOutputIndex].mSkipping + && mProximityInfo->getMatchedProximityId( + mInputIndex - 1, c, false) + == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) { + ++mSkippedCount; + --mProximityCount; + } return processSkipChar(c, isTerminal); } else { int inputIndexForProximity = mInputIndex; @@ -220,16 +232,27 @@ Correction::CorrectionType Correction::processCharAndCalcState( } } + // TODO: sum counters const bool checkProximityChars = - !(mSkipPos >= 0 || mExcessivePos >= 0 || mTransposedPos >= 0); + !(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0); int matchedProximityCharId = mProximityInfo->getMatchedProximityId( inputIndexForProximity, c, checkProximityChars); if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) { - if (skip) { + if (skip && mProximityCount == 0) { // Skip this letter and continue deeper ++mSkippedCount; return processSkipChar(c, isTerminal); + } else if (checkProximityChars + && inputIndexForProximity > 0 + && mCorrectionStates[mOutputIndex].mProximityMatching + && mCorrectionStates[mOutputIndex].mSkipping + && mProximityInfo->getMatchedProximityId( + inputIndexForProximity - 1, c, false) + == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) { + ++mSkippedCount; + --mProximityCount; + return processSkipChar(c, isTerminal); } else { return UNRELATED; } @@ -238,6 +261,7 @@ Correction::CorrectionType Correction::processCharAndCalcState( // proximity chars. So, we don't need to check proximity. 
mMatching = true; } else if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) { + mProximityMatching = true; incrementProximityCount(); } @@ -320,29 +344,116 @@ inline static void multiplyRate(const int rate, int *freq) { } } +// Edit distance (insert/delete/substitute/adjacent swap) over toBaseLowerCase()'d chars. +inline static int editDistance( + int* editDistanceTable, const unsigned short* input, + const int inputLength, const unsigned short* output, const int outputLength) { + // Row-major table, dp[a][b] == dp[a * lo + b]; needs (inputLength + 1) * (outputLength + 1) ints. + int* dp = editDistanceTable; + const int li = inputLength + 1; + const int lo = outputLength + 1; + for (int i = 0; i < li; ++i) { + dp[lo * i] = i; + } + for (int i = 0; i < lo; ++i) { + dp[i] = i; + } + + for (int i = 0; i < li - 1; ++i) { + for (int j = 0; j < lo - 1; ++j) { + const uint32_t ci = Dictionary::toBaseLowerCase(input[i]); + const uint32_t co = Dictionary::toBaseLowerCase(output[j]); + const uint16_t cost = (ci == co) ? 0 : 1; + dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1, + min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost)); + if (i > 0 && j > 0 // the swap check looks one char back in both words + && ci == Dictionary::toBaseLowerCase(output[j - 1]) + && co == Dictionary::toBaseLowerCase(input[i - 1])) { + dp[(i + 1) * lo + (j + 1)] = min( + dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost); + } + } + } + + if (DEBUG_EDIT_DISTANCE) { + LOGI("IN = %d, OUT = %d", inputLength, outputLength); + for (int i = 0; i < li; ++i) { + for (int j = 0; j < lo; ++j) { + LOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]); + } + } + } + return dp[li * lo - 1]; +} + ////////////////////// // RankingAlgorithm // ////////////////////// /* static */ int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int outputIndex, - const int freq, const bool sameLength, const Correction* correction) { + const int freq, const bool sameLength, int* editDistanceTable, + const Correction* correction) { const int excessivePos = correction->getExcessivePos(); const int transposedPos = correction->getTransposedPos(); const int
inputLength = correction->mInputLength; const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER; const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER; const ProximityInfo *proximityInfo = correction->mProximityInfo; + const int skipCount = correction->mSkippedCount; + const int proximityMatchedCount = correction->mProximityCount; // TODO: use mExcessiveCount - const int matchCount = inputLength - correction->mProximityCount - (excessivePos >= 0 ? 1 : 0); - const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); + int matchCount = inputLength - correction->mProximityCount - (excessivePos >= 0 ? 1 : 0); const unsigned short* word = correction->mWord; - const bool skipped = correction->mSkippedCount > 0; + const bool skipped = skipCount > 0; + + // ----- TODO: use edit distance here as follows? ---------------------- / + //if (!skipped && excessivePos < 0 && transposedPos < 0) { + // const int ed = editDistance(dp, proximityInfo->getInputWord(), + // inputLength, word, outputIndex + 1); + // matchCount = outputIndex + 1 - ed; + // if (ed == 1 && !sameLength) ++matchCount; + //} + // const int ed = editDistance(dp, proximityInfo->getInputWord(), + // inputLength, word, outputIndex + 1); + // if (ed == 1 && !sameLength) ++matchCount; ------------------------ / + int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); // TODO: Demote by edit distance int finalFreq = freq * matchWeight; + // +1 +11/-12 + /*if (inputLength == outputIndex && !skipped && excessivePos < 0 && transposedPos < 0) { + const int ed = editDistance(dp, proximityInfo->getInputWord(), + inputLength, word, outputIndex + 1); + if (ed == 1) { + multiplyRate(160, &finalFreq); + } + }*/ + if (inputLength == outputIndex && excessivePos < 0 && transposedPos < 0 + && (proximityMatchedCount > 0 || skipped)) { + const int ed = editDistance(editDistanceTable, proximityInfo->getPrimaryInputWord(), + inputLength, word, outputIndex + 1); + if (ed == 1) { + 
multiplyRate(160, &finalFreq); + } + } + + // TODO: Promote properly? + //if (skipCount == 1 && excessivePos < 0 && transposedPos < 0 && inputLength == outputIndex + // && !sameLength) { + // multiplyRate(150, &finalFreq); + //} + //if (skipCount == 0 && excessivePos < 0 && transposedPos < 0 && inputLength == outputIndex + // && !sameLength) { + // multiplyRate(150, &finalFreq); + //} + //if (skipCount == 0 && excessivePos < 0 && transposedPos < 0 + // && inputLength == outputIndex + 1) { + // multiplyRate(150, &finalFreq); + //} + if (skipped) { if (inputLength >= 2) { const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE @@ -389,7 +500,7 @@ int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const multiplyIntCapped(typedLetterMultiplier, &finalFreq); multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); } - if (DEBUG_DICT) { + if (DEBUG_DICT_FULL) { LOGI("calc: %d, %d", outputIndex, sameLength); } if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq); diff --git a/native/src/correction.h b/native/src/correction.h index 2fa8c905d..9d385a44e 100644 --- a/native/src/correction.h +++ b/native/src/correction.h @@ -120,6 +120,8 @@ private: int mTerminalInputIndex; int mTerminalOutputIndex; unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; + // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot. 
+ int mEditDistanceTable[MAX_WORD_LENGTH_INTERNAL * MAX_WORD_LENGTH_INTERNAL]; CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL]; @@ -132,11 +134,13 @@ private: bool mNeedsToTraverseAllNodes; bool mMatching; bool mSkipping; + bool mProximityMatching; class RankingAlgorithm { public: static int calculateFinalFreq(const int inputIndex, const int depth, - const int freq, const bool sameLength, const Correction* correction); + const int freq, const bool sameLength, int *editDistanceTable, + const Correction* correction); static int calcFreqForSplitTwoWords(const int firstFreq, const int secondFreq, const Correction* correction); }; diff --git a/native/src/correction_state.h b/native/src/correction_state.h index d30d13c85..267deda9b 100644 --- a/native/src/correction_state.h +++ b/native/src/correction_state.h @@ -33,6 +33,7 @@ struct CorrectionState { int8_t mSkipPos; // should be signed bool mMatching; bool mSkipping; + bool mProximityMatching; bool mNeedsToTraverseAllNodes; }; @@ -47,6 +48,7 @@ inline static void initCorrectionState(CorrectionState *state, const int rootPos state->mSkippedCount = 0; state->mMatching = false; state->mSkipping = false; + state->mProximityMatching = false; state->mNeedsToTraverseAllNodes = traverseAll; state->mSkipPos = -1; } diff --git a/native/src/defines.h b/native/src/defines.h index c1838d341..c1d08e695 100644 --- a/native/src/defines.h +++ b/native/src/defines.h @@ -94,20 +94,36 @@ static void prof_out(void) { #endif #define DEBUG_DICT true #define DEBUG_DICT_FULL false +#define DEBUG_EDIT_DISTANCE false #define DEBUG_SHOW_FOUND_WORD DEBUG_DICT_FULL #define DEBUG_NODE DEBUG_DICT_FULL #define DEBUG_TRACE DEBUG_DICT_FULL #define DEBUG_PROXIMITY_INFO true +#define DUMP_WORD(word, length) do { dumpWord(word, length); } while(0) + +static char charBuf[50]; + +static void dumpWord(const unsigned short* word, const int length) { + for (int i = 0; i < length; ++i) { + charBuf[i] = word[i]; + } + charBuf[length] = 0; + LOGI("[ %s 
]", charBuf); +} + #else // FLAG_DBG #define DEBUG_DICT false #define DEBUG_DICT_FULL false +#define DEBUG_EDIT_DISTANCE false #define DEBUG_SHOW_FOUND_WORD false #define DEBUG_NODE false #define DEBUG_TRACE false #define DEBUG_PROXIMITY_INFO false +#define DUMP_WORD(word, length) + #endif // FLAG_DBG #ifndef U_SHORT_MAX diff --git a/native/src/proximity_info.cpp b/native/src/proximity_info.cpp index d437e251a..361bdacbf 100644 --- a/native/src/proximity_info.cpp +++ b/native/src/proximity_info.cpp @@ -68,6 +68,10 @@ bool ProximityInfo::hasSpaceProximity(const int x, const int y) const { void ProximityInfo::setInputParams(const int* inputCodes, const int inputLength) { mInputCodes = inputCodes; mInputLength = inputLength; + for (int i = 0; i < inputLength; ++i) { + mPrimaryInputWord[i] = getPrimaryCharAt(i); + } + mPrimaryInputWord[inputLength] = 0; } inline const int* ProximityInfo::getProximityCharsAt(const int index) const { diff --git a/native/src/proximity_info.h b/native/src/proximity_info.h index d9ed46f5b..75fc8fb63 100644 --- a/native/src/proximity_info.h +++ b/native/src/proximity_info.h @@ -46,6 +46,9 @@ public: ProximityType getMatchedProximityId( const int index, const unsigned short c, const bool checkProximityChars) const; bool sameAsTyped(const unsigned short *word, int length) const; + const unsigned short* getPrimaryInputWord() const { + return mPrimaryInputWord; + } private: int getStartIndexFromCoordinates(const int x, const int y) const; @@ -59,6 +62,7 @@ private: const int *mInputCodes; uint32_t *mProximityCharsArray; int mInputLength; + unsigned short mPrimaryInputWord[MAX_WORD_LENGTH_INTERNAL]; }; } // namespace latinime diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index 6517bc0b8..6bc350505 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -187,8 +187,9 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, 
mCorrection->initCorrection(mProximityInfo, mInputLength, maxDepth); PROF_END(0); + // TODO: remove PROF_START(1); - getSuggestionCandidates(-1, -1, -1); + // Note: This line is intentionally left blank PROF_END(1); PROF_START(2);