refactor distance cache

Change-Id: I21b54b356641a63d7be17fd34b9ede7a63ec738a
main
Satoshi Kataoka 2013-01-15 19:16:38 +09:00
parent 6cee61deeb
commit a9763f93d7
4 changed files with 66 additions and 30 deletions

View File

@ -22,6 +22,7 @@
#include "correction.h" #include "correction.h"
#include "defines.h" #include "defines.h"
#include "proximity_info_state.h" #include "proximity_info_state.h"
#include "suggest_utils.h"
namespace latinime { namespace latinime {
@ -673,27 +674,9 @@ inline static bool isUpperCase(unsigned short c) {
if (i < adjustedProximityMatchedCount) { if (i < adjustedProximityMatchedCount) {
multiplyIntCapped(typedLetterMultiplier, &finalFreq); multiplyIntCapped(typedLetterMultiplier, &finalFreq);
} }
if (squaredDistance >= 0) { const float factor =
// Promote or demote the score according to the distance from the sweet spot SuggestUtils::getDistanceScalingFactor(static_cast<float>(squaredDistance));
static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f; if (factor > 0.0f) {
static const float B = 1.0f;
static const float C = 0.5f;
static const float MIN = 0.3f;
static const float R1 = NEUTRAL_SCORE_SQUARED_RADIUS;
static const float R2 = HALF_SCORE_SQUARED_RADIUS;
const float x = static_cast<float>(squaredDistance)
/ ProximityInfoState::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR;
const float factor = max((x < R1)
? (A * (R1 - x) + B * x) / R1
: (B * (R2 - x) + C * (x - R1)) / (R2 - R1), MIN);
// factor is a piecewise linear function like:
// A -_ .
// ^-_ .
// B \ .
// \_ .
// C ------------.
// .
// 0 R1 R2 .
multiplyRate((int)(factor * 100.0f), &finalFreq); multiplyRate((int)(factor * 100.0f), &finalFreq);
} else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) { } else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) {
multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq);

View File

@ -101,7 +101,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
mTimes.clear(); mTimes.clear();
mInputIndice.clear(); mInputIndice.clear();
mLengthCache.clear(); mLengthCache.clear();
mDistanceCache.clear(); mDistanceCache_G.clear();
mNearKeysVector.clear(); mNearKeysVector.clear();
mSearchKeysVector.clear(); mSearchKeysVector.clear();
mSpeedRates.clear(); mSpeedRates.clear();
@ -210,7 +210,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
const int keyCount = mProximityInfo->getKeyCount(); const int keyCount = mProximityInfo->getKeyCount();
mNearKeysVector.resize(mSampledInputSize); mNearKeysVector.resize(mSampledInputSize);
mSearchKeysVector.resize(mSampledInputSize); mSearchKeysVector.resize(mSampledInputSize);
mDistanceCache.resize(mSampledInputSize * keyCount); mDistanceCache_G.resize(mSampledInputSize * keyCount);
for (int i = lastSavedInputSize; i < mSampledInputSize; ++i) { for (int i = lastSavedInputSize; i < mSampledInputSize; ++i) {
mNearKeysVector[i].reset(); mNearKeysVector[i].reset();
mSearchKeysVector[i].reset(); mSearchKeysVector[i].reset();
@ -221,7 +221,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
const int y = mSampledInputYs[i]; const int y = mSampledInputYs[i];
const float normalizedSquaredDistance = const float normalizedSquaredDistance =
mProximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y); mProximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y);
mDistanceCache[index] = normalizedSquaredDistance; mDistanceCache_G[index] = normalizedSquaredDistance;
if (normalizedSquaredDistance < NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { if (normalizedSquaredDistance < NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) {
mNearKeysVector[i][k] = true; mNearKeysVector[i][k] = true;
} }
@ -686,7 +686,7 @@ float ProximityInfoState::getPointToKeyLength(
const int keyId = mProximityInfo->getKeyIndexOf(codePoint); const int keyId = mProximityInfo->getKeyIndexOf(codePoint);
if (keyId != NOT_AN_INDEX) { if (keyId != NOT_AN_INDEX) {
const int index = inputIndex * mProximityInfo->getKeyCount() + keyId; const int index = inputIndex * mProximityInfo->getKeyCount() + keyId;
return min(mDistanceCache[index] * scale, mMaxPointToKeyLength); return min(mDistanceCache_G[index] * scale, mMaxPointToKeyLength);
} }
if (isSkippableCodePoint(codePoint)) { if (isSkippableCodePoint(codePoint)) {
return 0.0f; return 0.0f;
@ -695,7 +695,7 @@ float ProximityInfoState::getPointToKeyLength(
return MAX_POINT_TO_KEY_LENGTH; return MAX_POINT_TO_KEY_LENGTH;
} }
float ProximityInfoState::getPointToKeyLength(const int inputIndex, const int codePoint) const { float ProximityInfoState::getPointToKeyLength_G(const int inputIndex, const int codePoint) const {
return getPointToKeyLength(inputIndex, codePoint, 1.0f); return getPointToKeyLength(inputIndex, codePoint, 1.0f);
} }
@ -706,7 +706,7 @@ float ProximityInfoState::getPointToKeyByIdLength(
const int inputIndex, const int keyId, const float scale) const { const int inputIndex, const int keyId, const float scale) const {
if (keyId != NOT_AN_INDEX) { if (keyId != NOT_AN_INDEX) {
const int index = inputIndex * mProximityInfo->getKeyCount() + keyId; const int index = inputIndex * mProximityInfo->getKeyCount() + keyId;
return min(mDistanceCache[index] * scale, mMaxPointToKeyLength); return min(mDistanceCache_G[index] * scale, mMaxPointToKeyLength);
} }
// If the char is not a key on the keyboard then return the max length. // If the char is not a key on the keyboard then return the max length.
return static_cast<float>(MAX_POINT_TO_KEY_LENGTH); return static_cast<float>(MAX_POINT_TO_KEY_LENGTH);

View File

@ -58,7 +58,7 @@ class ProximityInfoState {
mHasTouchPositionCorrectionData(false), mMostCommonKeyWidthSquare(0), mLocaleStr(), mHasTouchPositionCorrectionData(false), mMostCommonKeyWidthSquare(0), mLocaleStr(),
mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0), mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0),
mIsContinuationPossible(false), mSampledInputXs(), mSampledInputYs(), mTimes(), mIsContinuationPossible(false), mSampledInputXs(), mSampledInputYs(), mTimes(),
mInputIndice(), mLengthCache(), mBeelineSpeedPercentiles(), mDistanceCache(), mInputIndice(), mLengthCache(), mBeelineSpeedPercentiles(), mDistanceCache_G(),
mSpeedRates(), mDirections(), mCharProbabilities(), mNearKeysVector(), mSpeedRates(), mDirections(), mCharProbabilities(), mNearKeysVector(),
mSearchKeysVector(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0) { mSearchKeysVector(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0) {
memset(mInputCodes, 0, sizeof(mInputCodes)); memset(mInputCodes, 0, sizeof(mInputCodes));
@ -157,7 +157,7 @@ class ProximityInfoState {
float getPointToKeyByIdLength(const int inputIndex, const int keyId, const float scale) const; float getPointToKeyByIdLength(const int inputIndex, const int keyId, const float scale) const;
float getPointToKeyByIdLength(const int inputIndex, const int keyId) const; float getPointToKeyByIdLength(const int inputIndex, const int keyId) const;
float getPointToKeyLength(const int inputIndex, const int codePoint, const float scale) const; float getPointToKeyLength(const int inputIndex, const int codePoint, const float scale) const;
float getPointToKeyLength(const int inputIndex, const int codePoint) const; float getPointToKeyLength_G(const int inputIndex, const int codePoint) const;
ProximityType getMatchedProximityId(const int index, const int c, ProximityType getMatchedProximityId(const int index, const int c,
const bool checkProximityChars, int *proximityIndex = 0) const; const bool checkProximityChars, int *proximityIndex = 0) const;
@ -274,7 +274,7 @@ class ProximityInfoState {
std::vector<int> mInputIndice; std::vector<int> mInputIndice;
std::vector<int> mLengthCache; std::vector<int> mLengthCache;
std::vector<int> mBeelineSpeedPercentiles; std::vector<int> mBeelineSpeedPercentiles;
std::vector<float> mDistanceCache; std::vector<float> mDistanceCache_G;
std::vector<float> mSpeedRates; std::vector<float> mSpeedRates;
std::vector<float> mDirections; std::vector<float> mDirections;
// probabilities of skipping or mapping to a key for each point. // probabilities of skipping or mapping to a key for each point.

View File

@ -0,0 +1,53 @@
/*
* Copyright (C) 2013 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_SUGGEST_UTILS_H
#define LATINIME_SUGGEST_UTILS_H
#include "defines.h"
namespace latinime {
class SuggestUtils {
public:
static float getDistanceScalingFactor(float normalizedSquaredDistance) {
if (normalizedSquaredDistance < 0.0f) {
return -1.0f;
}
// Promote or demote the score according to the distance from the sweet spot
static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f;
static const float B = 1.0f;
static const float C = 0.5f;
static const float MIN = 0.3f;
static const float R1 = NEUTRAL_SCORE_SQUARED_RADIUS;
static const float R2 = HALF_SCORE_SQUARED_RADIUS;
const float x = static_cast<float>(normalizedSquaredDistance)
/ ProximityInfoState::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR;
const float factor = max((x < R1)
? (A * (R1 - x) + B * x) / R1
: (B * (R2 - x) + C * (x - R1)) / (R2 - R1), MIN);
// factor is a piecewise linear function like:
// A -_ .
// ^-_ .
// B \ .
// \_ .
// C ------------.
// .
// 0 R1 R2 .
return factor;
}
};
} // namespace latinime
#endif // LATINIME_SUGGEST_UTILS_H