/* * Copyright (C) 2011 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef LATINIME_CORRECTION_H #define LATINIME_CORRECTION_H #include // for memset() #include "correction_state.h" #include "defines.h" #include "proximity_info_state.h" namespace latinime { class ProximityInfo; class Correction { public: typedef enum { TRAVERSE_ALL_ON_TERMINAL, TRAVERSE_ALL_NOT_ON_TERMINAL, UNRELATED, ON_TERMINAL, NOT_ON_TERMINAL } CorrectionType; Correction() : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false), mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0), mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0), mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0), mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0), mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0), mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false), mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false), mSkipping(false), mProximityInfoState() { memset(mWord, 0, sizeof(mWord)); memset(mDistances, 0, sizeof(mDistances)); memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable)); // NOTE: mCorrectionStates is an array of instances. // No need to initialize it explicitly here. } // Non virtual inline destructor -- never inherit this class ~Correction() {} void resetCorrection(); void initCorrection(const ProximityInfo *pi, const int inputSize, const int maxDepth); void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll); // TODO: remove void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance, const bool doAutoCompletion, const int maxErrors); void checkState() const; bool sameAsTyped() const; bool initProcessState(const int index); int getInputIndex() const; bool needsToPrune() const; int pushAndGetTotalTraverseCount() { return ++mTotalTraverseCount; } int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, const int wordCount, const bool isSpaceProximity, const int *word) const; int getFinalProbability(const int probability, int **word, int *wordLength); int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength, const int inputSize); CorrectionType processCharAndCalcState(const int c, const bool isTerminal); ///////////////////////// // Tree helper methods int goDownTree(const int parentIndex, const int childCount, const int firstChildPos); inline int getTreeSiblingPos(const int index) const { return mCorrectionStates[index].mSiblingPos; } inline void setTreeSiblingPos(const int index, const int pos) { mCorrectionStates[index].mSiblingPos = pos; } inline int getTreeParentIndex(const int index) const { return mCorrectionStates[index].mParentIndex; } class RankingAlgorithm { public: static int calculateFinalProbability(const int inputIndex, const int depth, const int probability, int *editDistanceTable, const Correction *correction, const int inputSize); static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, const int wordCount, const Correction *correction, const bool isSpaceProximity, const int *word); static float calcNormalizedScore(const int *before, const int beforeLength, const int *after, const int afterLength, const int score); static int editDistance(const int *before, const int beforeLength, const int *after, const int afterLength); private: static const int MAX_INITIAL_SCORE = 255; }; // proximity info state void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes, const int inputSize, const int *xCoordinates, const int *yCoordinates) { mProximityInfoState.initInputParams(0, MAX_POINT_TO_KEY_LENGTH, proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false); } const int *getPrimaryInputWord() const { return mProximityInfoState.getPrimaryInputWord(); } int getPrimaryCodePointAt(const int index) const { return mProximityInfoState.getPrimaryCodePointAt(index); } private: DISALLOW_COPY_AND_ASSIGN(Correction); ///////////////////////// // static inline utils // ///////////////////////// static const int TWO_31ST_DIV_255 = S_INT_MAX / 255; static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) { return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX); } static const int TWO_31ST_DIV_2 = S_INT_MAX / 2; AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) { const int temp = *base; if (temp != S_INT_MAX) { // Branch if multiplier == 2 for the optimization if (multiplier < 0) { if (DEBUG_DICT) { ASSERT(false); } AKLOGI("--- Invalid multiplier: %d", multiplier); } else if (multiplier == 0) { *base = 0; } else if (multiplier == 2) { *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX; } else { // TODO: This overflow check gives a wrong answer when, for example, // temp = 2^16 + 1 and multiplier = 2^17 + 1. // Fix this behavior. const int tempRetval = temp * multiplier; *base = tempRetval >= temp ? tempRetval : S_INT_MAX; } } } AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) { if (n <= 0) return 1; if (base == 2) { return n < 31 ? 1 << n : S_INT_MAX; } int ret = base; for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret); return ret; } AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) { if (*freq != S_INT_MAX) { if (*freq > 1000000) { *freq /= 100; multiplyIntCapped(rate, freq); } else { multiplyIntCapped(rate, freq); *freq /= 100; } } } inline int getSpaceProximityPos() const { return mSpaceProximityPos; } inline int getMissingSpacePos() const { return mMissingSpacePos; } inline int getSkipPos() const { return mSkipPos; } inline int getExcessivePos() const { return mExcessivePos; } inline int getTransposedPos() const { return mTransposedPos; } inline void incrementInputIndex(); inline void incrementOutputIndex(); inline void startToTraverseAllNodes(); inline bool isSingleQuote(const int c); inline CorrectionType processSkipChar(const int c, const bool isTerminal, const bool inputIndexIncremented); inline CorrectionType processUnrelatedCorrectionType(); inline void addCharToCurrentWord(const int c); inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength, const int inputSize); static const int TYPED_LETTER_MULTIPLIER = 2; static const int FULL_WORD_MULTIPLIER = 2; const ProximityInfo *mProximityInfo; bool mUseFullEditDistance; bool mDoAutoCompletion; int mMaxEditDistance; int mMaxDepth; int mInputSize; int mSpaceProximityPos; int mMissingSpacePos; int mTerminalInputIndex; int mTerminalOutputIndex; int mMaxErrors; int mTotalTraverseCount; // The following arrays are state buffer. int mWord[MAX_WORD_LENGTH]; int mDistances[MAX_WORD_LENGTH]; // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N. // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot. int mEditDistanceTable[(MAX_WORD_LENGTH + 1) * (MAX_WORD_LENGTH + 1)]; CorrectionState mCorrectionStates[MAX_WORD_LENGTH]; // The following member variables are being used as cache values of the correction state. bool mNeedsToTraverseAllNodes; int mOutputIndex; int mInputIndex; int mEquivalentCharCount; int mProximityCount; int mExcessiveCount; int mTransposedCount; int mSkippedCount; int mTransposedPos; int mExcessivePos; int mSkipPos; bool mLastCharExceeded; bool mMatching; bool mProximityMatching; bool mAdditionalProximityMatching; bool mExceeding; bool mTransposing; bool mSkipping; ProximityInfoState mProximityInfoState; }; inline void Correction::incrementInputIndex() { ++mInputIndex; } AK_FORCE_INLINE void Correction::incrementOutputIndex() { ++mOutputIndex; mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex; mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount; mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount; mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount; mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount; mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount; mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos; mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos; mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos; mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded; mCorrectionStates[mOutputIndex].mMatching = mMatching; mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching; mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching; mCorrectionStates[mOutputIndex].mTransposing = mTransposing; mCorrectionStates[mOutputIndex].mExceeding = mExceeding; mCorrectionStates[mOutputIndex].mSkipping = mSkipping; } inline void Correction::startToTraverseAllNodes() { mNeedsToTraverseAllNodes = true; } inline bool Correction::isSingleQuote(const int c) { const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex); return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE); } AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c, const bool isTerminal, const bool inputIndexIncremented) { addCharToCurrentWord(c); mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0); mTerminalOutputIndex = mOutputIndex; incrementOutputIndex(); if (mNeedsToTraverseAllNodes && isTerminal) { return TRAVERSE_ALL_ON_TERMINAL; } return TRAVERSE_ALL_NOT_ON_TERMINAL; } inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() { // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType mTerminalInputIndex = mInputIndex; mTerminalOutputIndex = mOutputIndex; return UNRELATED; } AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input, const int inputSize, const int *output, const int outputLength) { // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH] is not touched. // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j]. // Assuming that dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated, // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize]. int *const current = editDistanceTable + outputLength * (inputSize + 1); const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1); const int *const prevprev = outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0; current[0] = outputLength; const int co = toBaseLowerCase(output[outputLength - 1]); const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0; for (int i = 1; i <= inputSize; ++i) { const int ci = toBaseLowerCase(input[i - 1]); const int cost = (ci == co) ? 0 : 1; current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost)); if (i >= 2 && prevprev && ci == prevCO && co == toBaseLowerCase(input[i - 2])) { current[i] = min(current[i], prevprev[i - 2] + 1); } } } AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) { mWord[mOutputIndex] = c; const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord(); calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord, mOutputIndex + 1); } inline int Correction::getFinalProbabilityInternal(const int probability, int **word, int *wordLength, const int inputSize) { const int outputIndex = mTerminalOutputIndex; const int inputIndex = mTerminalInputIndex; *wordLength = outputIndex + 1; *word = mWord; int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability( inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize); return finalProbability; } } // namespace latinime #endif // LATINIME_CORRECTION_H