diff --git a/native/jni/src/char_utils.h b/native/jni/src/char_utils.h index b42de6607..c632b79b8 100644 --- a/native/jni/src/char_utils.h +++ b/native/jni/src/char_utils.h @@ -55,7 +55,7 @@ inline static int toBaseCodePoint(int c) { return c; } -inline static int toLowerCase(const int c) { +AK_FORCE_INLINE static int toLowerCase(const int c) { if (isAsciiUpper(c)) { return toAsciiLower(c); } else if (isAscii(c)) { @@ -64,7 +64,7 @@ inline static int toLowerCase(const int c) { return static_cast(latin_tolower(static_cast(c))); } -inline static int toBaseLowerCase(const int c) { +AK_FORCE_INLINE static int toBaseLowerCase(const int c) { return toLowerCase(toBaseCodePoint(c)); } diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp index d7b67c98b..50f33fe23 100644 --- a/native/jni/src/correction.cpp +++ b/native/jni/src/correction.cpp @@ -60,29 +60,6 @@ inline static void dumpEditDistance10ForDebug(int *editDistanceTable, } } -inline static void calcEditDistanceOneStep(int *editDistanceTable, const int *input, - const int inputSize, const int *output, const int outputLength) { - // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH_INTERNAL] is not touched. - // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j]. - // Assuming that dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated, - // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize]. - int *const current = editDistanceTable + outputLength * (inputSize + 1); - const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1); - const int *const prevprev = - outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0; - current[0] = outputLength; - const int co = toBaseLowerCase(output[outputLength - 1]); - const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0; - for (int i = 1; i <= inputSize; ++i) { - const int ci = toBaseLowerCase(input[i - 1]); - const uint16_t cost = (ci == co) ? 0 : 1; - current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost)); - if (i >= 2 && prevprev && ci == prevCO && co == toBaseLowerCase(input[i - 2])) { - current[i] = min(current[i], prevprev[i - 2] + 1); - } - } -} - inline static int getCurrentEditDistance(int *editDistanceTable, const int editDistanceTableWidth, const int outputLength, const int inputSize) { if (DEBUG_EDIT_DISTANCE) { @@ -91,14 +68,6 @@ inline static int getCurrentEditDistance(int *editDistanceTable, const int editD return editDistanceTable[(editDistanceTableWidth + 1) * (outputLength) + inputSize]; } -////////////////////// -// inline functions // -////////////////////// -inline bool Correction::isSingleQuote(const int c) { - const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex); - return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE); -} - //////////////// // Correction // //////////////// @@ -174,17 +143,6 @@ int Correction::getFinalProbabilityForSubQueue(const int probability, int **word return getFinalProbabilityInternal(probability, word, wordLength, inputSize); } -int Correction::getFinalProbabilityInternal(const int probability, int **word, int *wordLength, - const int inputSize) { - const int outputIndex = mTerminalOutputIndex; - const int inputIndex = mTerminalInputIndex; - *wordLength = outputIndex + 1; - *word = mWord; - int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability( - inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize); - return finalProbability; -} - bool Correction::initProcessState(const int outputIndex) { if (mCorrectionStates[outputIndex].mChildCount <= 0) { return false; @@ -228,42 +186,6 @@ int Correction::getInputIndex() const { return mInputIndex; } -void Correction::incrementInputIndex() { - ++mInputIndex; -} - -void Correction::incrementOutputIndex() { - ++mOutputIndex; - mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex; - mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount; - mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; - mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; - mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; - - mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount; - mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount; - mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount; - mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount; - mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; - - mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos; - mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos; - mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos; - - mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded; - - mCorrectionStates[mOutputIndex].mMatching = mMatching; - mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching; - mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching; - mCorrectionStates[mOutputIndex].mTransposing = mTransposing; - mCorrectionStates[mOutputIndex].mExceeding = mExceeding; - mCorrectionStates[mOutputIndex].mSkipping = mSkipping; -} - -void Correction::startToTraverseAllNodes() { - mNeedsToTraverseAllNodes = true; -} - bool Correction::needsToPrune() const { // TODO: use edit distance here return mOutputIndex - 1 >= mMaxDepth || mProximityCount > mMaxEditDistance @@ -271,39 +193,11 @@ bool Correction::needsToPrune() const { || (!mDoAutoCompletion && (mOutputIndex > mInputSize)); } -void Correction::addCharToCurrentWord(const int c) { - mWord[mOutputIndex] = c; - const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord(); - calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord, - mOutputIndex + 1); -} - -Correction::CorrectionType Correction::processSkipChar(const int c, const bool isTerminal, - const bool inputIndexIncremented) { - addCharToCurrentWord(c); - mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0); - mTerminalOutputIndex = mOutputIndex; - if (mNeedsToTraverseAllNodes && isTerminal) { - incrementOutputIndex(); - return TRAVERSE_ALL_ON_TERMINAL; - } else { - incrementOutputIndex(); - return TRAVERSE_ALL_NOT_ON_TERMINAL; - } -} - -Correction::CorrectionType Correction::processUnrelatedCorrectionType() { - // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType - mTerminalInputIndex = mInputIndex; - mTerminalOutputIndex = mOutputIndex; - return UNRELATED; -} - -inline bool isEquivalentChar(ProximityType type) { +inline static bool isEquivalentChar(ProximityType type) { return type == EQUIVALENT_CHAR; } -inline bool isProximityCharOrEquivalentChar(ProximityType type) { +inline static bool isProximityCharOrEquivalentChar(ProximityType type) { return type == EQUIVALENT_CHAR || type == NEAR_PROXIMITY_CHAR; } @@ -625,29 +519,6 @@ Correction::CorrectionType Correction::processCharAndCalcState(const int c, cons } } -/* static */ int Correction::powerIntCapped(const int base, const int n) { - if (n <= 0) return 1; - if (base == 2) { - return n < 31 ? 1 << n : S_INT_MAX; - } else { - int ret = base; - for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret); - return ret; - } -} - -/* static */ void Correction::multiplyRate(const int rate, int *freq) { - if (*freq != S_INT_MAX) { - if (*freq > 1000000) { - *freq /= 100; - multiplyIntCapped(rate, freq); - } else { - multiplyIntCapped(rate, freq); - *freq /= 100; - } - } -} - inline static int getQuoteCount(const int *word, const int length) { int quoteCount = 0; for (int i = 0; i < length; ++i) { diff --git a/native/jni/src/correction.h b/native/jni/src/correction.h index d0b196cf2..105a95f38 100644 --- a/native/jni/src/correction.h +++ b/native/jni/src/correction.h @@ -145,7 +145,7 @@ class Correction { } static const int TWO_31ST_DIV_2 = S_INT_MAX / 2; - inline static void multiplyIntCapped(const int multiplier, int *base) { + AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) { const int temp = *base; if (temp != S_INT_MAX) { // Branch if multiplier == 2 for the optimization @@ -168,8 +168,28 @@ class Correction { } } - static int powerIntCapped(const int base, const int n); - static void multiplyRate(const int rate, int *freq); + AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) { + if (n <= 0) return 1; + if (base == 2) { + return n < 31 ? 1 << n : S_INT_MAX; + } else { + int ret = base; + for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret); + return ret; + } + } + + AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) { + if (*freq != S_INT_MAX) { + if (*freq > 1000000) { + *freq /= 100; + multiplyIntCapped(rate, freq); + } else { + multiplyIntCapped(rate, freq); + *freq /= 100; + } + } + } inline int getSpaceProximityPos() const { return mSpaceProximityPos; @@ -194,6 +214,8 @@ class Correction { inline void incrementOutputIndex(); inline void startToTraverseAllNodes(); inline bool isSingleQuote(const int c); + inline CorrectionType processSkipChar(const int c, const bool isTerminal, + const bool inputIndexIncremented); inline CorrectionType processUnrelatedCorrectionType(); inline void addCharToCurrentWord(const int c); inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength, @@ -224,9 +246,6 @@ class Correction { // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot. int mEditDistanceTable[(MAX_WORD_LENGTH_INTERNAL + 1) * (MAX_WORD_LENGTH_INTERNAL + 1)]; - CorrectionType processSkipChar(const int c, const bool isTerminal, - const bool inputIndexIncremented); - CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL]; // The following member variables are being used as cache values of the correction state. @@ -254,5 +273,109 @@ class Correction { bool mSkipping; ProximityInfoState mProximityInfoState; }; + +inline void Correction::incrementInputIndex() { + ++mInputIndex; +} + +AK_FORCE_INLINE void Correction::incrementOutputIndex() { + ++mOutputIndex; + mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex; + mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount; + mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; + mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; + mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; + + mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount; + mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount; + mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount; + mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount; + mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; + + mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos; + mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos; + mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos; + + mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded; + + mCorrectionStates[mOutputIndex].mMatching = mMatching; + mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching; + mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching; + mCorrectionStates[mOutputIndex].mTransposing = mTransposing; + mCorrectionStates[mOutputIndex].mExceeding = mExceeding; + mCorrectionStates[mOutputIndex].mSkipping = mSkipping; +} + +inline void Correction::startToTraverseAllNodes() { + mNeedsToTraverseAllNodes = true; +} + +inline bool Correction::isSingleQuote(const int c) { + const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex); + return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE); +} + +AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c, + const bool isTerminal, const bool inputIndexIncremented) { + addCharToCurrentWord(c); + mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0); + mTerminalOutputIndex = mOutputIndex; + if (mNeedsToTraverseAllNodes && isTerminal) { + incrementOutputIndex(); + return TRAVERSE_ALL_ON_TERMINAL; + } else { + incrementOutputIndex(); + return TRAVERSE_ALL_NOT_ON_TERMINAL; + } +} + +inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() { + // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType + mTerminalInputIndex = mInputIndex; + mTerminalOutputIndex = mOutputIndex; + return UNRELATED; +} + +AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input, + const int inputSize, const int *output, const int outputLength) { + // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH_INTERNAL] is not touched. + // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j]. + // Assuming that dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated, + // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize]. + int *const current = editDistanceTable + outputLength * (inputSize + 1); + const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1); + const int *const prevprev = + outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0; + current[0] = outputLength; + const int co = toBaseLowerCase(output[outputLength - 1]); + const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0; + for (int i = 1; i <= inputSize; ++i) { + const int ci = toBaseLowerCase(input[i - 1]); + const uint16_t cost = (ci == co) ? 0 : 1; + current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost)); + if (i >= 2 && prevprev && ci == prevCO && co == toBaseLowerCase(input[i - 2])) { + current[i] = min(current[i], prevprev[i - 2] + 1); + } + } +} + +AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) { + mWord[mOutputIndex] = c; + const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord(); + calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord, + mOutputIndex + 1); +} + +inline int Correction::getFinalProbabilityInternal(const int probability, int **word, + int *wordLength, const int inputSize) { + const int outputIndex = mTerminalOutputIndex; + const int inputIndex = mTerminalInputIndex; + *wordLength = outputIndex + 1; + *word = mWord; + int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability( + inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize); + return finalProbability; +} + } // namespace latinime #endif // LATINIME_CORRECTION_H diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 095487416..40bc958d1 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -383,6 +383,12 @@ template inline T max(T a, T b) { return a > b ? a : b; } #define NELEMS(x) (sizeof(x) / sizeof((x)[0])) +#ifdef __GNUC__ +#define AK_FORCE_INLINE __attribute__((always_inline)) __inline__ +#else // __GNUC__ +#define AK_FORCE_INLINE inline +#endif // __GNUC__ + // The ratio of neutral area radius to sweet spot radius. #define NEUTRAL_AREA_RADIUS_RATIO 1.3f diff --git a/native/jni/src/unigram_dictionary.cpp b/native/jni/src/unigram_dictionary.cpp index 820f9ab12..d134a47e6 100644 --- a/native/jni/src/unigram_dictionary.cpp +++ b/native/jni/src/unigram_dictionary.cpp @@ -365,7 +365,7 @@ void UnigramDictionary::getSuggestionCandidates(const bool useFullEditDistance, } } -inline void UnigramDictionary::onTerminal(const int probability, +void UnigramDictionary::onTerminal(const int probability, const TerminalAttributes& terminalAttributes, Correction *correction, WordsPriorityQueuePool *queuePool, const bool addToMasterQueue, const int currentWordIndex) const {