diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp index 5f630f033..e892c8591 100644 --- a/native/jni/src/correction.cpp +++ b/native/jni/src/correction.cpp @@ -22,6 +22,7 @@ #include "correction.h" #include "defines.h" #include "proximity_info_state.h" +#include "suggest_utils.h" namespace latinime { @@ -673,27 +674,9 @@ inline static bool isUpperCase(unsigned short c) { if (i < adjustedProximityMatchedCount) { multiplyIntCapped(typedLetterMultiplier, &finalFreq); } - if (squaredDistance >= 0) { - // Promote or demote the score according to the distance from the sweet spot - static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f; - static const float B = 1.0f; - static const float C = 0.5f; - static const float MIN = 0.3f; - static const float R1 = NEUTRAL_SCORE_SQUARED_RADIUS; - static const float R2 = HALF_SCORE_SQUARED_RADIUS; - const float x = static_cast(squaredDistance) - / ProximityInfoState::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR; - const float factor = max((x < R1) - ? (A * (R1 - x) + B * x) / R1 - : (B * (R2 - x) + C * (x - R1)) / (R2 - R1), MIN); - // factor is a piecewise linear function like: - // A -_ . - // ^-_ . - // B \ . - // \_ . - // C ------------. - // . - // 0 R1 R2 . + const float factor = + SuggestUtils::getDistanceScalingFactor(static_cast(squaredDistance)); + if (factor > 0.0f) { multiplyRate((int)(factor * 100.0f), &finalFreq); } else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) { multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/proximity_info_state.cpp index d0cc4acc2..aa029297e 100644 --- a/native/jni/src/proximity_info_state.cpp +++ b/native/jni/src/proximity_info_state.cpp @@ -101,7 +101,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi mTimes.clear(); mInputIndice.clear(); mLengthCache.clear(); - mDistanceCache.clear(); + mDistanceCache_G.clear(); mNearKeysVector.clear(); mSearchKeysVector.clear(); mSpeedRates.clear(); @@ -210,7 +210,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi const int keyCount = mProximityInfo->getKeyCount(); mNearKeysVector.resize(mSampledInputSize); mSearchKeysVector.resize(mSampledInputSize); - mDistanceCache.resize(mSampledInputSize * keyCount); + mDistanceCache_G.resize(mSampledInputSize * keyCount); for (int i = lastSavedInputSize; i < mSampledInputSize; ++i) { mNearKeysVector[i].reset(); mSearchKeysVector[i].reset(); @@ -221,7 +221,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi const int y = mSampledInputYs[i]; const float normalizedSquaredDistance = mProximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y); - mDistanceCache[index] = normalizedSquaredDistance; + mDistanceCache_G[index] = normalizedSquaredDistance; if (normalizedSquaredDistance < NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { mNearKeysVector[i][k] = true; } @@ -686,7 +686,7 @@ float ProximityInfoState::getPointToKeyLength( const int keyId = mProximityInfo->getKeyIndexOf(codePoint); if (keyId != NOT_AN_INDEX) { const int index = inputIndex * mProximityInfo->getKeyCount() + keyId; - return min(mDistanceCache[index] * scale, mMaxPointToKeyLength); + return min(mDistanceCache_G[index] * scale, mMaxPointToKeyLength); } if (isSkippableCodePoint(codePoint)) { return 0.0f; @@ -695,7 +695,7 @@ float ProximityInfoState::getPointToKeyLength( return MAX_POINT_TO_KEY_LENGTH; } -float ProximityInfoState::getPointToKeyLength(const int inputIndex, const int codePoint) const { +float ProximityInfoState::getPointToKeyLength_G(const int inputIndex, const int codePoint) const { return getPointToKeyLength(inputIndex, codePoint, 1.0f); } @@ -706,7 +706,7 @@ float ProximityInfoState::getPointToKeyByIdLength( const int inputIndex, const int keyId, const float scale) const { if (keyId != NOT_AN_INDEX) { const int index = inputIndex * mProximityInfo->getKeyCount() + keyId; - return min(mDistanceCache[index] * scale, mMaxPointToKeyLength); + return min(mDistanceCache_G[index] * scale, mMaxPointToKeyLength); } // If the char is not a key on the keyboard then return the max length. return static_cast(MAX_POINT_TO_KEY_LENGTH); diff --git a/native/jni/src/proximity_info_state.h b/native/jni/src/proximity_info_state.h index c8d417dd0..d747bae2a 100644 --- a/native/jni/src/proximity_info_state.h +++ b/native/jni/src/proximity_info_state.h @@ -58,7 +58,7 @@ class ProximityInfoState { mHasTouchPositionCorrectionData(false), mMostCommonKeyWidthSquare(0), mLocaleStr(), mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0), mIsContinuationPossible(false), mSampledInputXs(), mSampledInputYs(), mTimes(), - mInputIndice(), mLengthCache(), mBeelineSpeedPercentiles(), mDistanceCache(), + mInputIndice(), mLengthCache(), mBeelineSpeedPercentiles(), mDistanceCache_G(), mSpeedRates(), mDirections(), mCharProbabilities(), mNearKeysVector(), mSearchKeysVector(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0) { memset(mInputCodes, 0, sizeof(mInputCodes)); @@ -157,7 +157,7 @@ class ProximityInfoState { float getPointToKeyByIdLength(const int inputIndex, const int keyId, const float scale) const; float getPointToKeyByIdLength(const int inputIndex, const int keyId) const; float getPointToKeyLength(const int inputIndex, const int codePoint, const float scale) const; - float getPointToKeyLength(const int inputIndex, const int codePoint) const; + float getPointToKeyLength_G(const int inputIndex, const int codePoint) const; ProximityType getMatchedProximityId(const int index, const int c, const bool checkProximityChars, int *proximityIndex = 0) const; @@ -274,7 +274,7 @@ class ProximityInfoState { std::vector mInputIndice; std::vector mLengthCache; std::vector mBeelineSpeedPercentiles; - std::vector mDistanceCache; + std::vector mDistanceCache_G; std::vector mSpeedRates; std::vector mDirections; // probabilities of skipping or mapping to a key for each point. diff --git a/native/jni/src/suggest_utils.h b/native/jni/src/suggest_utils.h new file mode 100644 index 000000000..a3b3c12dd --- /dev/null +++ b/native/jni/src/suggest_utils.h @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SUGGEST_UTILS_H +#define LATINIME_SUGGEST_UTILS_H + +#include "defines.h" + +namespace latinime { +class SuggestUtils { + public: + static float getDistanceScalingFactor(float normalizedSquaredDistance) { + if (normalizedSquaredDistance < 0.0f) { + return -1.0f; + } + // Promote or demote the score according to the distance from the sweet spot + static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f; + static const float B = 1.0f; + static const float C = 0.5f; + static const float MIN = 0.3f; + static const float R1 = NEUTRAL_SCORE_SQUARED_RADIUS; + static const float R2 = HALF_SCORE_SQUARED_RADIUS; + const float x = static_cast(normalizedSquaredDistance) + / ProximityInfoState::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR; + const float factor = max((x < R1) + ? (A * (R1 - x) + B * x) / R1 + : (B * (R2 - x) + C * (x - R1)) / (R2 - R1), MIN); + // factor is a piecewise linear function like: + // A -_ . + // ^-_ . + // B \ . + // \_ . + // C ------------. + // . + // 0 R1 R2 . + return factor; + } +}; +} // namespace latinime +#endif // LATINIME_SUGGEST_UTILS_H