diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp index 76234f840..0c65939e0 100644 --- a/native/jni/src/correction.cpp +++ b/native/jni/src/correction.cpp @@ -675,7 +675,7 @@ inline static bool isUpperCase(unsigned short c) { multiplyIntCapped(typedLetterMultiplier, &finalFreq); } const float factor = - SuggestUtils::getDistanceScalingFactor(static_cast(squaredDistance)); + SuggestUtils::getLengthScalingFactor(static_cast(squaredDistance)); if (factor > 0.0f) { multiplyRate(static_cast(factor * 100.0f), &finalFreq); } else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) { diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/proximity_info_state.cpp index a10b260e1..cc5b736bd 100644 --- a/native/jni/src/proximity_info_state.cpp +++ b/native/jni/src/proximity_info_state.cpp @@ -81,7 +81,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi mSampledTimes.clear(); mSampledInputIndice.clear(); mSampledLengthCache.clear(); - mSampledDistanceCache_G.clear(); + mSampledNormalizedSquaredLengthCache.clear(); mSampledNearKeySets.clear(); mSampledSearchKeySets.clear(); mSpeedRates.clear(); @@ -122,14 +122,15 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (mSampledInputSize > 0) { ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize, lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs, - &mSampledNearKeySets, &mSampledDistanceCache_G); + &mSampledNearKeySets, &mSampledNormalizedSquaredLengthCache); if (isGeometric) { // updates probabilities of skipping or mapping each key for all points. ProximityInfoStateUtils::updateAlignPointProbabilities( mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(), mProximityInfo->getKeyCount(), lastSavedInputSize, mSampledInputSize, &mSampledInputXs, &mSampledInputYs, &mSpeedRates, &mSampledLengthCache, - &mSampledDistanceCache_G, &mSampledNearKeySets, &mCharProbabilities); + &mSampledNormalizedSquaredLengthCache, &mSampledNearKeySets, + &mCharProbabilities); ProximityInfoStateUtils::updateSampledSearchKeySets(mProximityInfo, mSampledInputSize, lastSavedInputSize, &mSampledLengthCache, &mSampledNearKeySets, &mSampledSearchKeySets, @@ -171,7 +172,7 @@ float ProximityInfoState::getPointToKeyLength( const int keyId = mProximityInfo->getKeyIndexOf(codePoint); if (keyId != NOT_AN_INDEX) { const int index = inputIndex * mProximityInfo->getKeyCount() + keyId; - return min(mSampledDistanceCache_G[index], mMaxPointToKeyLength); + return min(mSampledNormalizedSquaredLengthCache[index], mMaxPointToKeyLength); } if (isIntentionalOmissionCodePoint(codePoint)) { return 0.0f; @@ -183,7 +184,8 @@ float ProximityInfoState::getPointToKeyLength( float ProximityInfoState::getPointToKeyByIdLength( const int inputIndex, const int keyId) const { return ProximityInfoStateUtils::getPointToKeyByIdLength(mMaxPointToKeyLength, - &mSampledDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId); + &mSampledNormalizedSquaredLengthCache, mProximityInfo->getKeyCount(), inputIndex, + keyId); } // In the following function, c is the current character of the dictionary word currently examined. diff --git a/native/jni/src/proximity_info_state.h b/native/jni/src/proximity_info_state.h index 9bba751d0..bbe8af240 100644 --- a/native/jni/src/proximity_info_state.h +++ b/native/jni/src/proximity_info_state.h @@ -49,8 +49,8 @@ class ProximityInfoState { mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0), mIsContinuousSuggestionPossible(false), mSampledInputXs(), mSampledInputYs(), mSampledTimes(), mSampledInputIndice(), mSampledLengthCache(), - mBeelineSpeedPercentiles(), mSampledDistanceCache_G(), mSpeedRates(), mDirections(), - mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), + mBeelineSpeedPercentiles(), mSampledNormalizedSquaredLengthCache(), mSpeedRates(), + mDirections(), mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), mSampledSearchKeyVectors(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0), mMostProbableStringProbability(0.0f) { memset(mInputProximities, 0, sizeof(mInputProximities)); @@ -147,7 +147,9 @@ class ProximityInfoState { return mIsContinuousSuggestionPossible; } + // TODO: Rename s/Length/NormalizedSquaredLength/ float getPointToKeyByIdLength(const int inputIndex, const int keyId) const; + // TODO: Rename s/Length/NormalizedSquaredLength/ float getPointToKeyLength(const int inputIndex, const int codePoint) const; ProximityType getProximityType(const int index, const int codePoint, @@ -231,7 +233,7 @@ class ProximityInfoState { std::vector mSampledInputIndice; std::vector mSampledLengthCache; std::vector mBeelineSpeedPercentiles; - std::vector mSampledDistanceCache_G; + std::vector mSampledNormalizedSquaredLengthCache; std::vector mSpeedRates; std::vector mDirections; // probabilities of skipping or mapping to a key for each point. diff --git a/native/jni/src/proximity_info_state_utils.cpp b/native/jni/src/proximity_info_state_utils.cpp index df70cffdf..359673cd8 100644 --- a/native/jni/src/proximity_info_state_utils.cpp +++ b/native/jni/src/proximity_info_state_utils.cpp @@ -225,13 +225,13 @@ namespace latinime { const int lastSavedInputSize, const float verticalSweetSpotScale, const std::vector *const sampledInputXs, const std::vector *const sampledInputYs, - std::vector *SampledNearKeySets, - std::vector *SampledDistanceCache_G) { - SampledNearKeySets->resize(sampledInputSize); + std::vector *sampledNearKeySets, + std::vector *sampledNormalizedSquaredLengthCache) { + sampledNearKeySets->resize(sampledInputSize); const int keyCount = proximityInfo->getKeyCount(); - SampledDistanceCache_G->resize(sampledInputSize * keyCount); + sampledNormalizedSquaredLengthCache->resize(sampledInputSize * keyCount); for (int i = lastSavedInputSize; i < sampledInputSize; ++i) { - (*SampledNearKeySets)[i].reset(); + (*sampledNearKeySets)[i].reset(); for (int k = 0; k < keyCount; ++k) { const int index = i * keyCount + k; const int x = (*sampledInputXs)[i]; @@ -239,10 +239,10 @@ namespace latinime { const float normalizedSquaredDistance = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG( k, x, y, verticalSweetSpotScale); - (*SampledDistanceCache_G)[index] = normalizedSquaredDistance; + (*sampledNormalizedSquaredLengthCache)[index] = normalizedSquaredDistance; if (normalizedSquaredDistance < ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { - (*SampledNearKeySets)[i][k] = true; + (*sampledNearKeySets)[i][k] = true; } } } @@ -642,11 +642,11 @@ namespace latinime { // This function basically converts from a length to an edit distance. Accordingly, it's obviously // wrong to compare with mMaxPointToKeyLength. /* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength, - const std::vector *const SampledDistanceCache_G, const int keyCount, + const std::vector *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId) { if (keyId != NOT_AN_INDEX) { const int index = inputIndex * keyCount + keyId; - return min((*SampledDistanceCache_G)[index], maxPointToKeyLength); + return min((*sampledNormalizedSquaredLengthCache)[index], maxPointToKeyLength); } // If the char is not a key on the keyboard then return the max length. return static_cast(MAX_VALUE_FOR_WEIGHTING); @@ -660,8 +660,8 @@ namespace latinime { const std::vector *const sampledInputYs, const std::vector *const sampledSpeedRates, const std::vector *const sampledLengthCache, - const std::vector *const SampledDistanceCache_G, - std::vector *SampledNearKeySets, + const std::vector *const sampledNormalizedSquaredLengthCache, + std::vector *sampledNearKeySets, std::vector > *charProbabilities) { charProbabilities->resize(sampledInputSize); // Calculates probabilities of using a point as a correlated point with the character @@ -677,9 +677,9 @@ namespace latinime { float nearestKeyDistance = static_cast(MAX_VALUE_FOR_WEIGHTING); for (int j = 0; j < keyCount; ++j) { - if ((*SampledNearKeySets)[i].test(j)) { + if ((*sampledNearKeySets)[i].test(j)) { const float distance = getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j); if (distance < nearestKeyDistance) { nearestKeyDistance = distance; } @@ -758,14 +758,15 @@ namespace latinime { // Summing up probability densities of all near keys. float sumOfProbabilityDensities = 0.0f; for (int j = 0; j < keyCount; ++j) { - if ((*SampledNearKeySets)[i].test(j)) { + if ((*sampledNearKeySets)[i].test(j)) { float distance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j)); if (i == 0 && i != sampledInputSize - 1) { // For the first point, weighted average of distances from first point and the // next point to the key is used as a point to key distance. const float nextDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i + 1, j)); if (nextDistance < distance) { // The distance of the first point tends to bigger than continuing // points because the first touch by the user can be sloppy. @@ -779,7 +780,8 @@ namespace latinime { // For the first point, weighted average of distances from last point and // the previous point to the key is used as a point to key distance. const float previousDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i - 1, j)); if (previousDistance < distance) { // The distance of the last point tends to bigger than continuing points // because the last touch by the user can be sloppy. So we promote the @@ -798,14 +800,15 @@ namespace latinime { // Split the probability of an input point to keys that are close to the input point. for (int j = 0; j < keyCount; ++j) { - if ((*SampledNearKeySets)[i].test(j)) { + if ((*sampledNearKeySets)[i].test(j)) { float distance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j)); if (i == 0 && i != sampledInputSize - 1) { // For the first point, weighted average of distances from the first point and // the next point to the key is used as a point to key distance. const float prevDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i + 1, j)); if (prevDistance < distance) { distance = (distance + prevDistance * ProximityInfoParams::NEXT_DISTANCE_WEIGHT) @@ -815,7 +818,8 @@ namespace latinime { // For the first point, weighted average of distances from last point and // the previous point to the key is used as a point to key distance. const float prevDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i - 1, j)); if (prevDistance < distance) { distance = (distance + prevDistance * ProximityInfoParams::PREV_DISTANCE_WEIGHT) @@ -882,10 +886,10 @@ namespace latinime { for (int j = 0; j < keyCount; ++j) { hash_map_compat::iterator it = (*charProbabilities)[i].find(j); if (it == (*charProbabilities)[i].end()){ - (*SampledNearKeySets)[i].reset(j); + (*sampledNearKeySets)[i].reset(j); } else if(it->second < ProximityInfoParams::MIN_PROBABILITY) { // Erases from near keys vector because it has very low probability. - (*SampledNearKeySets)[i].reset(j); + (*sampledNearKeySets)[i].reset(j); (*charProbabilities)[i].erase(j); } else { it->second = -logf(it->second); @@ -899,7 +903,7 @@ namespace latinime { const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector *const sampledLengthCache, - const std::vector *const SampledNearKeySets, + const std::vector *const sampledNearKeySets, std::vector *sampledSearchKeySets, std::vector > *sampledSearchKeyVectors) { sampledSearchKeySets->resize(sampledInputSize); @@ -916,7 +920,7 @@ namespace latinime { if ((*sampledLengthCache)[j] - (*sampledLengthCache)[i] >= readForwordLength) { break; } - (*sampledSearchKeySets)[i] |= (*SampledNearKeySets)[j]; + (*sampledSearchKeySets)[i] |= (*sampledNearKeySets)[j]; } } const int keyCount = proximityInfo->getKeyCount(); diff --git a/native/jni/src/proximity_info_state_utils.h b/native/jni/src/proximity_info_state_utils.h index c9feb59a3..1837c7ab6 100644 --- a/native/jni/src/proximity_info_state_utils.h +++ b/native/jni/src/proximity_info_state_utils.h @@ -71,25 +71,25 @@ class ProximityInfoStateUtils { const std::vector *const sampledInputYs, const std::vector *const sampledSpeedRates, const std::vector *const sampledLengthCache, - const std::vector *const SampledDistanceCache_G, - std::vector *SampledNearKeySets, + const std::vector *const sampledNormalizedSquaredLengthCache, + std::vector *sampledNearKeySets, std::vector > *charProbabilities); static void updateSampledSearchKeySets(const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector *const sampledLengthCache, - const std::vector *const SampledNearKeySets, + const std::vector *const sampledNearKeySets, std::vector *sampledSearchKeySets, std::vector > *sampledSearchKeyVectors); static float getPointToKeyByIdLength(const float maxPointToKeyLength, - const std::vector *const SampledDistanceCache_G, const int keyCount, + const std::vector *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId); static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const float verticalSweetSpotScale, const std::vector *const sampledInputXs, const std::vector *const sampledInputYs, - std::vector *SampledNearKeySets, - std::vector *SampledDistanceCache_G); + std::vector *sampledNearKeySets, + std::vector *sampledNormalizedSquaredLengthCache); static void initPrimaryInputWord(const int inputSize, const int *const inputProximities, int *primaryInputWord); static void initNormalizedSquaredDistances(const ProximityInfo *const proximityInfo, diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index a7c042ada..fe0527639 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -146,6 +146,10 @@ class DicTraverseSession { return true; } + bool isTouchPositionCorrectionEnabled() const { + return mProximityInfoStates[0].touchPositionCorrectionEnabled(); + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); // threshold to start caching diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 52d54eb0f..2dcee343f 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -18,6 +18,7 @@ #define LATINIME_TYPING_WEIGHTING_H #include "defines.h" +#include "suggest_utils.h" #include "suggest/core/dicnode/dic_node_utils.h" #include "suggest/core/policy/weighting.h" #include "suggest/core/session/dic_traverse_session.h" @@ -70,10 +71,12 @@ class TypingWeighting : public Weighting { const int pointIndex = dicNode->getInputIndex(0); // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on // the keyboard (like accented letters) - const float length = min(ScoringParams::MAX_SPATIAL_DISTANCE, - traverseSession->getProximityInfoState(0)->getPointToKeyLength( - pointIndex, dicNode->getNodeCodePoint())); - const float weightedDistance = length * ScoringParams::DISTANCE_WEIGHT_LENGTH; + const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) + ->getPointToKeyLength(pointIndex, dicNode->getNodeCodePoint()); + const float normalizedDistance = SuggestUtils::getSweetSpotFactor( + traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); + const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; + const bool isFirstChar = pointIndex == 0; const bool isProximity = isProximityDicNode(traverseSession, dicNode); const float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST diff --git a/native/jni/src/suggest_utils.h b/native/jni/src/suggest_utils.h index aab9f7ba8..e053dd662 100644 --- a/native/jni/src/suggest_utils.h +++ b/native/jni/src/suggest_utils.h @@ -23,10 +23,8 @@ namespace latinime { class SuggestUtils { public: - static float getDistanceScalingFactor(const float normalizedSquaredDistance) { - if (normalizedSquaredDistance < 0.0f) { - return -1.0f; - } + // TODO: (OLD) Remove + static float getLengthScalingFactor(const float normalizedSquaredDistance) { // Promote or demote the score according to the distance from the sweet spot static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f; static const float B = 1.0f; @@ -50,6 +48,39 @@ class SuggestUtils { return factor; } + static float getSweetSpotFactor(const bool isTouchPositionCorrectionEnabled, + const float normalizedSquaredDistance) { + // Promote or demote the score according to the distance from the sweet spot + static const float A = 0.0f; + static const float B = 0.24f; + static const float C = 1.20f; + static const float R0 = 0.0f; + static const float R1 = 0.25f; // Sweet spot + static const float R2 = 1.0f; + const float x = normalizedSquaredDistance; + if (!isTouchPositionCorrectionEnabled) { + return min(C, x); + } + + // factor is a piecewise linear function like: + // C -------------. + // / . + // B / . + // -/ . + // A _-^ . + // . + // R0 R1 R2 . + + if (x < R0) { + return A; + } else if (x < R1) { + return (A * (R1 - x) + B * (x - R0)) / (R1 - R0); + } else if (x < R2) { + return (B * (R2 - x) + C * (x - R1)) / (R2 - R1); + } else { + return C; + } + } private: DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestUtils); };