am d4828d50: Refactor proximity info state

* commit 'd4828d5053ac30476b884c177235be0cac982c92':
  Refactor proximity info state
This commit is contained in:
Satoshi Kataoka 2013-01-21 22:46:09 -08:00 committed by Android Git Automerger
commit 41fcc80e14
4 changed files with 409 additions and 348 deletions

View file

@ -138,7 +138,11 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
} }
if (isGeometric) { if (isGeometric) {
// updates probabilities of skipping or mapping each key for all points. // updates probabilities of skipping or mapping each key for all points.
updateAlignPointProbabilities(lastSavedInputSize); ProximityInfoStateUtils::updateAlignPointProbabilities(
mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(),
keyCount, lastSavedInputSize, mSampledInputSize, &mSampledInputXs,
&mSampledInputYs, &mSpeedRates, &mLengthCache, &mDistanceCache_G,
&mNearKeysVector, &mCharProbabilities);
static const float READ_FORWORD_LENGTH_SCALE = 0.95f; static const float READ_FORWORD_LENGTH_SCALE = 0.95f;
const int readForwordLength = static_cast<int>( const int readForwordLength = static_cast<int>(
@ -307,16 +311,10 @@ float ProximityInfoState::getPointToKeyLength_G(const int inputIndex, const int
} }
// TODO: Remove the "scale" parameter // TODO: Remove the "scale" parameter
// This function basically converts from a length to an edit distance. Accordingly, it's obviously
// wrong to compare with mMaxPointToKeyLength.
float ProximityInfoState::getPointToKeyByIdLength( float ProximityInfoState::getPointToKeyByIdLength(
const int inputIndex, const int keyId, const float scale) const { const int inputIndex, const int keyId, const float scale) const {
if (keyId != NOT_AN_INDEX) { return ProximityInfoStateUtils::getPointToKeyByIdLength(mMaxPointToKeyLength,
const int index = inputIndex * mProximityInfo->getKeyCount() + keyId; &mDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId, scale);
return min(mDistanceCache_G[index] * scale, mMaxPointToKeyLength);
}
// If the char is not a key on the keyboard then return the max length.
return static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
} }
float ProximityInfoState::getPointToKeyByIdLength(const int inputIndex, const int keyId) const { float ProximityInfoState::getPointToKeyByIdLength(const int inputIndex, const int keyId) const {
@ -442,32 +440,6 @@ float ProximityInfoState::getDirection(const int index0, const int index1) const
&mSampledInputXs, &mSampledInputYs, index0, index1); &mSampledInputXs, &mSampledInputYs, index0, index1);
} }
float ProximityInfoState::getPointAngle(const int index) const {
if (index <= 0 || index >= mSampledInputSize - 1) {
return 0.0f;
}
const float previousDirection = getDirection(index - 1, index);
const float nextDirection = getDirection(index, index + 1);
const float directionDiff = getAngleDiff(previousDirection, nextDirection);
return directionDiff;
}
float ProximityInfoState::getPointsAngle(
const int index0, const int index1, const int index2) const {
if (index0 < 0 || index0 > mSampledInputSize - 1) {
return 0.0f;
}
if (index1 < 0 || index1 > mSampledInputSize - 1) {
return 0.0f;
}
if (index2 < 0 || index2 > mSampledInputSize - 1) {
return 0.0f;
}
const float previousDirection = getDirection(index0, index1);
const float nextDirection = getDirection(index1, index2);
return getAngleDiff(previousDirection, nextDirection);
}
float ProximityInfoState::getLineToKeyDistance( float ProximityInfoState::getLineToKeyDistance(
const int from, const int to, const int keyId, const bool extend) const { const int from, const int to, const int keyId, const bool extend) const {
if (from < 0 || from > mSampledInputSize - 1) { if (from < 0 || from > mSampledInputSize - 1) {
@ -488,293 +460,6 @@ float ProximityInfoState::getLineToKeyDistance(
keyX, keyY, x0, y0, x1, y1, extend); keyX, keyY, x0, y0, x1, y1, extend);
} }
// Updates probabilities of aligning to some keys and skipping.
// Word suggestion should be based on this probabilities.
void ProximityInfoState::updateAlignPointProbabilities(const int start) {
static const float MIN_PROBABILITY = 0.000001f;
static const float MAX_SKIP_PROBABILITY = 0.95f;
static const float SKIP_FIRST_POINT_PROBABILITY = 0.01f;
static const float SKIP_LAST_POINT_PROBABILITY = 0.1f;
static const float MIN_SPEED_RATE_FOR_SKIP_PROBABILITY = 0.15f;
static const float SPEED_WEIGHT_FOR_SKIP_PROBABILITY = 0.9f;
static const float SLOW_STRAIGHT_WEIGHT_FOR_SKIP_PROBABILITY = 0.6f;
static const float NEAREST_DISTANCE_WEIGHT = 0.5f;
static const float NEAREST_DISTANCE_BIAS = 0.5f;
static const float NEAREST_DISTANCE_WEIGHT_FOR_LAST = 0.6f;
static const float NEAREST_DISTANCE_BIAS_FOR_LAST = 0.4f;
static const float ANGLE_WEIGHT = 0.90f;
static const float DEEP_CORNER_ANGLE_THRESHOLD = M_PI_F * 60.0f / 180.0f;
static const float SKIP_DEEP_CORNER_PROBABILITY = 0.1f;
static const float CORNER_ANGLE_THRESHOLD = M_PI_F * 30.0f / 180.0f;
static const float STRAIGHT_ANGLE_THRESHOLD = M_PI_F * 15.0f / 180.0f;
static const float SKIP_CORNER_PROBABILITY = 0.4f;
static const float SPEED_MARGIN = 0.1f;
static const float CENTER_VALUE_OF_NORMALIZED_DISTRIBUTION = 0.0f;
const int keyCount = mProximityInfo->getKeyCount();
mCharProbabilities.resize(mSampledInputSize);
// Calculates probabilities of using a point as a correlated point with the character
// for each point.
for (int i = start; i < mSampledInputSize; ++i) {
mCharProbabilities[i].clear();
// First, calculates skip probability. Starts form MIN_SKIP_PROBABILITY.
// Note that all values that are multiplied to this probability should be in [0.0, 1.0];
float skipProbability = MAX_SKIP_PROBABILITY;
const float currentAngle = getPointAngle(i);
const float speedRate = getSpeedRate(i);
float nearestKeyDistance = static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
for (int j = 0; j < keyCount; ++j) {
if (mNearKeysVector[i].test(j)) {
const float distance = getPointToKeyByIdLength(i, j);
if (distance < nearestKeyDistance) {
nearestKeyDistance = distance;
}
}
}
if (i == 0) {
skipProbability *= min(1.0f, nearestKeyDistance * NEAREST_DISTANCE_WEIGHT
+ NEAREST_DISTANCE_BIAS);
// Promote the first point
skipProbability *= SKIP_FIRST_POINT_PROBABILITY;
} else if (i == mSampledInputSize - 1) {
skipProbability *= min(1.0f, nearestKeyDistance * NEAREST_DISTANCE_WEIGHT_FOR_LAST
+ NEAREST_DISTANCE_BIAS_FOR_LAST);
// Promote the last point
skipProbability *= SKIP_LAST_POINT_PROBABILITY;
} else {
// If the current speed is relatively slower than adjacent keys, we promote this point.
if (getSpeedRate(i - 1) - SPEED_MARGIN > speedRate
&& speedRate < getSpeedRate(i + 1) - SPEED_MARGIN) {
if (currentAngle < CORNER_ANGLE_THRESHOLD) {
skipProbability *= min(1.0f, speedRate
* SLOW_STRAIGHT_WEIGHT_FOR_SKIP_PROBABILITY);
} else {
// If the angle is small enough, we promote this point more. (e.g. pit vs put)
skipProbability *= min(1.0f, speedRate * SPEED_WEIGHT_FOR_SKIP_PROBABILITY
+ MIN_SPEED_RATE_FOR_SKIP_PROBABILITY);
}
}
skipProbability *= min(1.0f, speedRate * nearestKeyDistance *
NEAREST_DISTANCE_WEIGHT + NEAREST_DISTANCE_BIAS);
// Adjusts skip probability by a rate depending on angle.
// ANGLE_RATE of skipProbability is adjusted by current angle.
skipProbability *= (M_PI_F - currentAngle) / M_PI_F * ANGLE_WEIGHT
+ (1.0f - ANGLE_WEIGHT);
if (currentAngle > DEEP_CORNER_ANGLE_THRESHOLD) {
skipProbability *= SKIP_DEEP_CORNER_PROBABILITY;
}
// We assume the angle of this point is the angle for point[i], point[i - 2]
// and point[i - 3]. The reason why we don't use the angle for point[i], point[i - 1]
// and point[i - 2] is this angle can be more affected by the noise.
const float prevAngle = getPointsAngle(i, i - 2, i - 3);
if (i >= 3 && prevAngle < STRAIGHT_ANGLE_THRESHOLD
&& currentAngle > CORNER_ANGLE_THRESHOLD) {
skipProbability *= SKIP_CORNER_PROBABILITY;
}
}
// probabilities must be in [0.0, MAX_SKIP_PROBABILITY];
ASSERT(skipProbability >= 0.0f);
ASSERT(skipProbability <= MAX_SKIP_PROBABILITY);
mCharProbabilities[i][NOT_AN_INDEX] = skipProbability;
// Second, calculates key probabilities by dividing the rest probability
// (1.0f - skipProbability).
const float inputCharProbability = 1.0f - skipProbability;
// TODO: The variance is critical for accuracy; thus, adjusting these parameter by machine
// learning or something would be efficient.
static const float SPEEDxANGLE_WEIGHT_FOR_STANDARD_DIVIATION = 0.3f;
static const float MAX_SPEEDxANGLE_RATE_FOR_STANDERD_DIVIATION = 0.25f;
static const float SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DIVIATION = 0.5f;
static const float MAX_SPEEDxNEAREST_RATE_FOR_STANDERD_DIVIATION = 0.15f;
static const float MIN_STANDERD_DIVIATION = 0.37f;
const float speedxAngleRate = min(speedRate * currentAngle / M_PI_F
* SPEEDxANGLE_WEIGHT_FOR_STANDARD_DIVIATION,
MAX_SPEEDxANGLE_RATE_FOR_STANDERD_DIVIATION);
const float speedxNearestKeyDistanceRate = min(speedRate * nearestKeyDistance
* SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DIVIATION,
MAX_SPEEDxNEAREST_RATE_FOR_STANDERD_DIVIATION);
const float sigma = speedxAngleRate + speedxNearestKeyDistanceRate + MIN_STANDERD_DIVIATION;
ProximityInfoUtils::NormalDistribution
distribution(CENTER_VALUE_OF_NORMALIZED_DISTRIBUTION, sigma);
static const float PREV_DISTANCE_WEIGHT = 0.5f;
static const float NEXT_DISTANCE_WEIGHT = 0.6f;
// Summing up probability densities of all near keys.
float sumOfProbabilityDensities = 0.0f;
for (int j = 0; j < keyCount; ++j) {
if (mNearKeysVector[i].test(j)) {
float distance = sqrtf(getPointToKeyByIdLength(i, j));
if (i == 0 && i != mSampledInputSize - 1) {
// For the first point, weighted average of distances from first point and the
// next point to the key is used as a point to key distance.
const float nextDistance = sqrtf(getPointToKeyByIdLength(i + 1, j));
if (nextDistance < distance) {
// The distance of the first point tends to bigger than continuing
// points because the first touch by the user can be sloppy.
// So we promote the first point if the distance of that point is larger
// than the distance of the next point.
distance = (distance + nextDistance * NEXT_DISTANCE_WEIGHT)
/ (1.0f + NEXT_DISTANCE_WEIGHT);
}
} else if (i != 0 && i == mSampledInputSize - 1) {
// For the first point, weighted average of distances from last point and
// the previous point to the key is used as a point to key distance.
const float previousDistance = sqrtf(getPointToKeyByIdLength(i - 1, j));
if (previousDistance < distance) {
// The distance of the last point tends to bigger than continuing points
// because the last touch by the user can be sloppy. So we promote the
// last point if the distance of that point is larger than the distance of
// the previous point.
distance = (distance + previousDistance * PREV_DISTANCE_WEIGHT)
/ (1.0f + PREV_DISTANCE_WEIGHT);
}
}
// TODO: Promote the first point when the extended line from the next input is near
// from a key. Also, promote the last point as well.
sumOfProbabilityDensities += distribution.getProbabilityDensity(distance);
}
}
// Split the probability of an input point to keys that are close to the input point.
for (int j = 0; j < keyCount; ++j) {
if (mNearKeysVector[i].test(j)) {
float distance = sqrtf(getPointToKeyByIdLength(i, j));
if (i == 0 && i != mSampledInputSize - 1) {
// For the first point, weighted average of distances from the first point and
// the next point to the key is used as a point to key distance.
const float prevDistance = sqrtf(getPointToKeyByIdLength(i + 1, j));
if (prevDistance < distance) {
distance = (distance + prevDistance * NEXT_DISTANCE_WEIGHT)
/ (1.0f + NEXT_DISTANCE_WEIGHT);
}
} else if (i != 0 && i == mSampledInputSize - 1) {
// For the first point, weighted average of distances from last point and
// the previous point to the key is used as a point to key distance.
const float prevDistance = sqrtf(getPointToKeyByIdLength(i - 1, j));
if (prevDistance < distance) {
distance = (distance + prevDistance * PREV_DISTANCE_WEIGHT)
/ (1.0f + PREV_DISTANCE_WEIGHT);
}
}
const float probabilityDensity = distribution.getProbabilityDensity(distance);
const float probability = inputCharProbability * probabilityDensity
/ sumOfProbabilityDensities;
mCharProbabilities[i][j] = probability;
}
}
}
if (DEBUG_POINTS_PROBABILITY) {
for (int i = 0; i < mSampledInputSize; ++i) {
std::stringstream sstream;
sstream << i << ", ";
sstream << "(" << mSampledInputXs[i] << ", " << mSampledInputYs[i] << "), ";
sstream << "Speed: "<< getSpeedRate(i) << ", ";
sstream << "Angle: "<< getPointAngle(i) << ", \n";
for (hash_map_compat<int, float>::iterator it = mCharProbabilities[i].begin();
it != mCharProbabilities[i].end(); ++it) {
if (it->first == NOT_AN_INDEX) {
sstream << it->first
<< "(skip):"
<< it->second
<< "\n";
} else {
sstream << it->first
<< "("
<< static_cast<char>(mProximityInfo->getCodePointOf(it->first))
<< "):"
<< it->second
<< "\n";
}
}
AKLOGI("%s", sstream.str().c_str());
}
}
// Decrease key probabilities of points which don't have the highest probability of that key
// among nearby points. Probabilities of the first point and the last point are not suppressed.
for (int i = max(start, 1); i < mSampledInputSize; ++i) {
for (int j = i + 1; j < mSampledInputSize; ++j) {
if (!suppressCharProbabilities(i, j)) {
break;
}
}
for (int j = i - 1; j >= max(start, 0); --j) {
if (!suppressCharProbabilities(i, j)) {
break;
}
}
}
// Converting from raw probabilities to log probabilities to calculate spatial distance.
for (int i = start; i < mSampledInputSize; ++i) {
for (int j = 0; j < keyCount; ++j) {
hash_map_compat<int, float>::iterator it = mCharProbabilities[i].find(j);
if (it == mCharProbabilities[i].end()){
mNearKeysVector[i].reset(j);
} else if(it->second < MIN_PROBABILITY) {
// Erases from near keys vector because it has very low probability.
mNearKeysVector[i].reset(j);
mCharProbabilities[i].erase(j);
} else {
it->second = -logf(it->second);
}
}
mCharProbabilities[i][NOT_AN_INDEX] = -logf(mCharProbabilities[i][NOT_AN_INDEX]);
}
}
// Decreases char probabilities of index0 by checking probabilities of a near point (index1) and
// increases char probabilities of index1 by checking probabilities of index0.
bool ProximityInfoState::suppressCharProbabilities(const int index0, const int index1) {
ASSERT(0 <= index0 && index0 < mSampledInputSize);
ASSERT(0 <= index1 && index1 < mSampledInputSize);
static const float SUPPRESSION_LENGTH_WEIGHT = 1.5f;
static const float MIN_SUPPRESSION_RATE = 0.1f;
static const float SUPPRESSION_WEIGHT = 0.5f;
static const float SUPPRESSION_WEIGHT_FOR_PROBABILITY_GAIN = 0.1f;
static const float SKIP_PROBABALITY_WEIGHT_FOR_PROBABILITY_GAIN = 0.3f;
const float keyWidthFloat = static_cast<float>(mProximityInfo->getMostCommonKeyWidth());
const float diff = fabsf(static_cast<float>(mLengthCache[index0] - mLengthCache[index1]));
if (diff > keyWidthFloat * SUPPRESSION_LENGTH_WEIGHT) {
return false;
}
const float suppressionRate = MIN_SUPPRESSION_RATE
+ diff / keyWidthFloat / SUPPRESSION_LENGTH_WEIGHT * SUPPRESSION_WEIGHT;
for (hash_map_compat<int, float>::iterator it = mCharProbabilities[index0].begin();
it != mCharProbabilities[index0].end(); ++it) {
hash_map_compat<int, float>::iterator it2 = mCharProbabilities[index1].find(it->first);
if (it2 != mCharProbabilities[index1].end() && it->second < it2->second) {
const float newProbability = it->second * suppressionRate;
const float suppression = it->second - newProbability;
it->second = newProbability;
// mCharProbabilities[index0][NOT_AN_INDEX] is the probability of skipping this point.
mCharProbabilities[index0][NOT_AN_INDEX] += suppression;
// Add the probability of the same key nearby index1
const float probabilityGain = min(suppression * SUPPRESSION_WEIGHT_FOR_PROBABILITY_GAIN,
mCharProbabilities[index1][NOT_AN_INDEX]
* SKIP_PROBABALITY_WEIGHT_FOR_PROBABILITY_GAIN);
it2->second += probabilityGain;
mCharProbabilities[index1][NOT_AN_INDEX] -= probabilityGain;
}
}
return true;
}
// Get a word that is detected by tracing the most probable string into codePointBuf and // Get a word that is detected by tracing the most probable string into codePointBuf and
// returns probability of generating the word. // returns probability of generating the word.
float ProximityInfoState::getMostProbableString(int *const codePointBuf) const { float ProximityInfoState::getMostProbableString(int *const codePointBuf) const {

View file

@ -17,7 +17,6 @@
#ifndef LATINIME_PROXIMITY_INFO_STATE_H #ifndef LATINIME_PROXIMITY_INFO_STATE_H
#define LATINIME_PROXIMITY_INFO_STATE_H #define LATINIME_PROXIMITY_INFO_STATE_H
#include <bitset>
#include <cstring> // for memset() #include <cstring> // for memset()
#include <vector> #include <vector>
@ -33,7 +32,6 @@ class ProximityInfo;
class ProximityInfoState { class ProximityInfoState {
public: public:
typedef std::bitset<MAX_KEY_COUNT_IN_A_KEYBOARD> NearKeycodesSet;
static const int NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR_LOG_2; static const int NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR_LOG_2;
static const int NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR; static const int NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR;
static const float NOT_A_DISTANCE_FLOAT; static const float NOT_A_DISTANCE_FLOAT;
@ -191,10 +189,6 @@ class ProximityInfoState {
// get xy direction // get xy direction
float getDirection(const int x, const int y) const; float getDirection(const int x, const int y) const;
float getPointAngle(const int index) const;
// Returns angle of three points. x, y, and z are indices.
float getPointsAngle(const int index0, const int index1, const int index2) const;
float getMostProbableString(int *const codePointBuf) const; float getMostProbableString(int *const codePointBuf) const;
float getProbability(const int index, const int charCode) const; float getProbability(const int index, const int charCode) const;
@ -205,7 +199,6 @@ class ProximityInfoState {
bool isKeyInSerchKeysAfterIndex(const int index, const int keyId) const; bool isKeyInSerchKeysAfterIndex(const int index, const int keyId) const;
private: private:
DISALLOW_COPY_AND_ASSIGN(ProximityInfoState); DISALLOW_COPY_AND_ASSIGN(ProximityInfoState);
typedef hash_map_compat<int, float> NearKeysDistanceMap;
///////////////////////////////////////// /////////////////////////////////////////
// Defined in proximity_info_state.cpp // // Defined in proximity_info_state.cpp //
///////////////////////////////////////// /////////////////////////////////////////
@ -226,24 +219,9 @@ class ProximityInfoState {
inline const int *getProximityCodePointsAt(const int index) const { inline const int *getProximityCodePointsAt(const int index) const {
return ProximityInfoStateUtils::getProximityCodePointsAt(mInputProximities, index); return ProximityInfoStateUtils::getProximityCodePointsAt(mInputProximities, index);
} }
float updateNearKeysDistances(const int x, const int y,
NearKeysDistanceMap *const currentNearKeysDistances);
bool isPrevLocalMin(const NearKeysDistanceMap *const currentNearKeysDistances,
const NearKeysDistanceMap *const prevNearKeysDistances,
const NearKeysDistanceMap *const prevPrevNearKeysDistances) const;
float getPointScore(
const int x, const int y, const int time, const bool last, const float nearest,
const float sumAngle, const NearKeysDistanceMap *const currentNearKeysDistances,
const NearKeysDistanceMap *const prevNearKeysDistances,
const NearKeysDistanceMap *const prevPrevNearKeysDistances) const;
bool checkAndReturnIsContinuationPossible(const int inputSize, const int *const xCoordinates, bool checkAndReturnIsContinuationPossible(const int inputSize, const int *const xCoordinates,
const int *const yCoordinates, const int *const times, const bool isGeometric) const; const int *const yCoordinates, const int *const times, const bool isGeometric) const;
void popInputData(); void popInputData();
void updateAlignPointProbabilities(const int start);
bool suppressCharProbabilities(const int index1, const int index2);
float calculateBeelineSpeedRate(const int id, const int inputSize,
const int *const xCoordinates, const int *const yCoordinates, const int * times) const;
// const // const
const ProximityInfo *mProximityInfo; const ProximityInfo *mProximityInfo;
@ -272,12 +250,12 @@ class ProximityInfoState {
// The vector for the key code set which holds nearby keys for each sampled input point // The vector for the key code set which holds nearby keys for each sampled input point
// 1. Used to calculate the probability of the key // 1. Used to calculate the probability of the key
// 2. Used to calculate mSearchKeysVector // 2. Used to calculate mSearchKeysVector
std::vector<NearKeycodesSet> mNearKeysVector; std::vector<ProximityInfoStateUtils::NearKeycodesSet> mNearKeysVector;
// The vector for the key code set which holds nearby keys of some trailing sampled input points // The vector for the key code set which holds nearby keys of some trailing sampled input points
// for each sampled input point. These nearby keys contain the next characters which can be in // for each sampled input point. These nearby keys contain the next characters which can be in
// the dictionary. Specifically, currently we are looking for keys nearby trailing sampled // the dictionary. Specifically, currently we are looking for keys nearby trailing sampled
// inputs including the current input point. // inputs including the current input point.
std::vector<NearKeycodesSet> mSearchKeysVector; std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSearchKeysVector;
bool mTouchPositionCorrectionEnabled; bool mTouchPositionCorrectionEnabled;
int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];

View file

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include <sstream> // for debug prints
#include <vector> #include <vector>
#include "defines.h" #include "defines.h"
@ -481,4 +482,371 @@ namespace latinime {
// TODO: Detect double letter more smartly // TODO: Detect double letter more smartly
return 0.01f + static_cast<float>(beelineDistance) / static_cast<float>(time) / averageSpeed; return 0.01f + static_cast<float>(beelineDistance) / static_cast<float>(time) / averageSpeed;
} }
/* static */ float ProximityInfoStateUtils::getPointAngle(
const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs, const int index) {
if (!sampledInputXs || !sampledInputYs) {
return 0.0f;
}
const int sampledInputSize = sampledInputXs->size();
if (index <= 0 || index >= sampledInputSize - 1) {
return 0.0f;
}
const float previousDirection = getDirection(sampledInputXs, sampledInputYs, index - 1, index);
const float nextDirection = getDirection(sampledInputXs, sampledInputYs, index, index + 1);
const float directionDiff = getAngleDiff(previousDirection, nextDirection);
return directionDiff;
}
/* static */ float ProximityInfoStateUtils::getPointsAngle(
const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs,
const int index0, const int index1, const int index2) {
if (!sampledInputXs || !sampledInputYs) {
return 0.0f;
}
const int sampledInputSize = sampledInputXs->size();
if (index0 < 0 || index0 > sampledInputSize - 1) {
return 0.0f;
}
if (index1 < 0 || index1 > sampledInputSize - 1) {
return 0.0f;
}
if (index2 < 0 || index2 > sampledInputSize - 1) {
return 0.0f;
}
const float previousDirection = getDirection(sampledInputXs, sampledInputYs, index0, index1);
const float nextDirection = getDirection(sampledInputXs, sampledInputYs, index1, index2);
return getAngleDiff(previousDirection, nextDirection);
}
// TODO: Remove the "scale" parameter
// This function basically converts from a length to an edit distance. Accordingly, it's obviously
// wrong to compare with mMaxPointToKeyLength.
/* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength,
const std::vector<float> *const distanceCache_G, const int keyCount,
const int inputIndex, const int keyId, const float scale) {
if (keyId != NOT_AN_INDEX) {
const int index = inputIndex * keyCount + keyId;
return min((*distanceCache_G)[index] * scale, maxPointToKeyLength);
}
// If the char is not a key on the keyboard then return the max length.
return static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
}
/* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength,
const std::vector<float> *const distanceCache_G, const int keyCount,
const int inputIndex, const int keyId) {
return getPointToKeyByIdLength(maxPointToKeyLength, distanceCache_G, keyCount, inputIndex,
keyId, 1.0f);
}
// Updates probabilities of aligning to some keys and skipping.
// Word suggestion should be based on this probabilities.
/* static */ void ProximityInfoStateUtils::updateAlignPointProbabilities(
const float maxPointToKeyLength, const int mostCommonKeyWidth, const int keyCount,
const int start, const int sampledInputSize, const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs,
const std::vector<float> *const sampledSpeedRates,
const std::vector<int> *const sampledLengthCache,
const std::vector<float> *const distanceCache_G,
std::vector<NearKeycodesSet> *nearKeysVector,
std::vector<hash_map_compat<int, float> > *charProbabilities) {
static const float MIN_PROBABILITY = 0.000001f;
static const float MAX_SKIP_PROBABILITY = 0.95f;
static const float SKIP_FIRST_POINT_PROBABILITY = 0.01f;
static const float SKIP_LAST_POINT_PROBABILITY = 0.1f;
static const float MIN_SPEED_RATE_FOR_SKIP_PROBABILITY = 0.15f;
static const float SPEED_WEIGHT_FOR_SKIP_PROBABILITY = 0.9f;
static const float SLOW_STRAIGHT_WEIGHT_FOR_SKIP_PROBABILITY = 0.6f;
static const float NEAREST_DISTANCE_WEIGHT = 0.5f;
static const float NEAREST_DISTANCE_BIAS = 0.5f;
static const float NEAREST_DISTANCE_WEIGHT_FOR_LAST = 0.6f;
static const float NEAREST_DISTANCE_BIAS_FOR_LAST = 0.4f;
static const float ANGLE_WEIGHT = 0.90f;
static const float DEEP_CORNER_ANGLE_THRESHOLD = M_PI_F * 60.0f / 180.0f;
static const float SKIP_DEEP_CORNER_PROBABILITY = 0.1f;
static const float CORNER_ANGLE_THRESHOLD = M_PI_F * 30.0f / 180.0f;
static const float STRAIGHT_ANGLE_THRESHOLD = M_PI_F * 15.0f / 180.0f;
static const float SKIP_CORNER_PROBABILITY = 0.4f;
static const float SPEED_MARGIN = 0.1f;
static const float CENTER_VALUE_OF_NORMALIZED_DISTRIBUTION = 0.0f;
charProbabilities->resize(sampledInputSize);
// Calculates probabilities of using a point as a correlated point with the character
// for each point.
for (int i = start; i < sampledInputSize; ++i) {
(*charProbabilities)[i].clear();
// First, calculates skip probability. Starts form MIN_SKIP_PROBABILITY.
// Note that all values that are multiplied to this probability should be in [0.0, 1.0];
float skipProbability = MAX_SKIP_PROBABILITY;
const float currentAngle = getPointAngle(sampledInputXs, sampledInputYs, i);
const float speedRate = (*sampledSpeedRates)[i];
float nearestKeyDistance = static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
for (int j = 0; j < keyCount; ++j) {
if ((*nearKeysVector)[i].test(j)) {
const float distance = getPointToKeyByIdLength(
maxPointToKeyLength, distanceCache_G, keyCount, i, j);
if (distance < nearestKeyDistance) {
nearestKeyDistance = distance;
}
}
}
if (i == 0) {
skipProbability *= min(1.0f, nearestKeyDistance * NEAREST_DISTANCE_WEIGHT
+ NEAREST_DISTANCE_BIAS);
// Promote the first point
skipProbability *= SKIP_FIRST_POINT_PROBABILITY;
} else if (i == sampledInputSize - 1) {
skipProbability *= min(1.0f, nearestKeyDistance * NEAREST_DISTANCE_WEIGHT_FOR_LAST
+ NEAREST_DISTANCE_BIAS_FOR_LAST);
// Promote the last point
skipProbability *= SKIP_LAST_POINT_PROBABILITY;
} else {
// If the current speed is relatively slower than adjacent keys, we promote this point.
if ((*sampledSpeedRates)[i - 1] - SPEED_MARGIN > speedRate
&& speedRate < (*sampledSpeedRates)[i + 1] - SPEED_MARGIN) {
if (currentAngle < CORNER_ANGLE_THRESHOLD) {
skipProbability *= min(1.0f, speedRate
* SLOW_STRAIGHT_WEIGHT_FOR_SKIP_PROBABILITY);
} else {
// If the angle is small enough, we promote this point more. (e.g. pit vs put)
skipProbability *= min(1.0f, speedRate * SPEED_WEIGHT_FOR_SKIP_PROBABILITY
+ MIN_SPEED_RATE_FOR_SKIP_PROBABILITY);
}
}
skipProbability *= min(1.0f, speedRate * nearestKeyDistance *
NEAREST_DISTANCE_WEIGHT + NEAREST_DISTANCE_BIAS);
// Adjusts skip probability by a rate depending on angle.
// ANGLE_RATE of skipProbability is adjusted by current angle.
skipProbability *= (M_PI_F - currentAngle) / M_PI_F * ANGLE_WEIGHT
+ (1.0f - ANGLE_WEIGHT);
if (currentAngle > DEEP_CORNER_ANGLE_THRESHOLD) {
skipProbability *= SKIP_DEEP_CORNER_PROBABILITY;
}
// We assume the angle of this point is the angle for point[i], point[i - 2]
// and point[i - 3]. The reason why we don't use the angle for point[i], point[i - 1]
// and point[i - 2] is this angle can be more affected by the noise.
const float prevAngle = getPointsAngle(sampledInputXs, sampledInputYs, i, i - 2, i - 3);
if (i >= 3 && prevAngle < STRAIGHT_ANGLE_THRESHOLD
&& currentAngle > CORNER_ANGLE_THRESHOLD) {
skipProbability *= SKIP_CORNER_PROBABILITY;
}
}
// probabilities must be in [0.0, MAX_SKIP_PROBABILITY];
ASSERT(skipProbability >= 0.0f);
ASSERT(skipProbability <= MAX_SKIP_PROBABILITY);
(*charProbabilities)[i][NOT_AN_INDEX] = skipProbability;
// Second, calculates key probabilities by dividing the rest probability
// (1.0f - skipProbability).
const float inputCharProbability = 1.0f - skipProbability;
// TODO: The variance is critical for accuracy; thus, adjusting these parameter by machine
// learning or something would be efficient.
static const float SPEEDxANGLE_WEIGHT_FOR_STANDARD_DIVIATION = 0.3f;
static const float MAX_SPEEDxANGLE_RATE_FOR_STANDERD_DIVIATION = 0.25f;
static const float SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DIVIATION = 0.5f;
static const float MAX_SPEEDxNEAREST_RATE_FOR_STANDERD_DIVIATION = 0.15f;
static const float MIN_STANDERD_DIVIATION = 0.37f;
const float speedxAngleRate = min(speedRate * currentAngle / M_PI_F
* SPEEDxANGLE_WEIGHT_FOR_STANDARD_DIVIATION,
MAX_SPEEDxANGLE_RATE_FOR_STANDERD_DIVIATION);
const float speedxNearestKeyDistanceRate = min(speedRate * nearestKeyDistance
* SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DIVIATION,
MAX_SPEEDxNEAREST_RATE_FOR_STANDERD_DIVIATION);
const float sigma = speedxAngleRate + speedxNearestKeyDistanceRate + MIN_STANDERD_DIVIATION;
ProximityInfoUtils::NormalDistribution
distribution(CENTER_VALUE_OF_NORMALIZED_DISTRIBUTION, sigma);
static const float PREV_DISTANCE_WEIGHT = 0.5f;
static const float NEXT_DISTANCE_WEIGHT = 0.6f;
// Summing up probability densities of all near keys.
float sumOfProbabilityDensities = 0.0f;
for (int j = 0; j < keyCount; ++j) {
if ((*nearKeysVector)[i].test(j)) {
float distance = sqrtf(getPointToKeyByIdLength(
maxPointToKeyLength, distanceCache_G, keyCount, i, j));
if (i == 0 && i != sampledInputSize - 1) {
// For the first point, weighted average of distances from first point and the
// next point to the key is used as a point to key distance.
const float nextDistance = sqrtf(getPointToKeyByIdLength(
maxPointToKeyLength, distanceCache_G, keyCount, i + 1, j));
if (nextDistance < distance) {
// The distance of the first point tends to bigger than continuing
// points because the first touch by the user can be sloppy.
// So we promote the first point if the distance of that point is larger
// than the distance of the next point.
distance = (distance + nextDistance * NEXT_DISTANCE_WEIGHT)
/ (1.0f + NEXT_DISTANCE_WEIGHT);
}
} else if (i != 0 && i == sampledInputSize - 1) {
// For the first point, weighted average of distances from last point and
// the previous point to the key is used as a point to key distance.
const float previousDistance = sqrtf(getPointToKeyByIdLength(
maxPointToKeyLength, distanceCache_G, keyCount, i - 1, j));
if (previousDistance < distance) {
// The distance of the last point tends to bigger than continuing points
// because the last touch by the user can be sloppy. So we promote the
// last point if the distance of that point is larger than the distance of
// the previous point.
distance = (distance + previousDistance * PREV_DISTANCE_WEIGHT)
/ (1.0f + PREV_DISTANCE_WEIGHT);
}
}
// TODO: Promote the first point when the extended line from the next input is near
// from a key. Also, promote the last point as well.
sumOfProbabilityDensities += distribution.getProbabilityDensity(distance);
}
}
// Split the probability of an input point to keys that are close to the input point.
for (int j = 0; j < keyCount; ++j) {
if ((*nearKeysVector)[i].test(j)) {
float distance = sqrtf(getPointToKeyByIdLength(
maxPointToKeyLength, distanceCache_G, keyCount, i, j));
if (i == 0 && i != sampledInputSize - 1) {
// For the first point, weighted average of distances from the first point and
// the next point to the key is used as a point to key distance.
const float prevDistance = sqrtf(getPointToKeyByIdLength(
maxPointToKeyLength, distanceCache_G, keyCount, i + 1, j));
if (prevDistance < distance) {
distance = (distance + prevDistance * NEXT_DISTANCE_WEIGHT)
/ (1.0f + NEXT_DISTANCE_WEIGHT);
}
} else if (i != 0 && i == sampledInputSize - 1) {
// For the first point, weighted average of distances from last point and
// the previous point to the key is used as a point to key distance.
const float prevDistance = sqrtf(getPointToKeyByIdLength(
maxPointToKeyLength, distanceCache_G, keyCount, i - 1, j));
if (prevDistance < distance) {
distance = (distance + prevDistance * PREV_DISTANCE_WEIGHT)
/ (1.0f + PREV_DISTANCE_WEIGHT);
}
}
const float probabilityDensity = distribution.getProbabilityDensity(distance);
const float probability = inputCharProbability * probabilityDensity
/ sumOfProbabilityDensities;
(*charProbabilities)[i][j] = probability;
}
}
}
if (DEBUG_POINTS_PROBABILITY) {
for (int i = 0; i < sampledInputSize; ++i) {
std::stringstream sstream;
sstream << i << ", ";
sstream << "(" << (*sampledInputXs)[i] << ", " << (*sampledInputYs)[i] << "), ";
sstream << "Speed: "<< (*sampledSpeedRates)[i] << ", ";
sstream << "Angle: "<< getPointAngle(sampledInputXs, sampledInputYs, i) << ", \n";
for (hash_map_compat<int, float>::iterator it = (*charProbabilities)[i].begin();
it != (*charProbabilities)[i].end(); ++it) {
if (it->first == NOT_AN_INDEX) {
sstream << it->first
<< "(skip):"
<< it->second
<< "\n";
} else {
sstream << it->first
<< "("
//<< static_cast<char>(mProximityInfo->getCodePointOf(it->first))
<< "):"
<< it->second
<< "\n";
}
}
AKLOGI("%s", sstream.str().c_str());
}
}
// Decrease key probabilities of points which don't have the highest probability of that key
// among nearby points. Probabilities of the first point and the last point are not suppressed.
for (int i = max(start, 1); i < sampledInputSize; ++i) {
for (int j = i + 1; j < sampledInputSize; ++j) {
if (!suppressCharProbabilities(
mostCommonKeyWidth, sampledInputSize, sampledLengthCache, i, j,
charProbabilities)) {
break;
}
}
for (int j = i - 1; j >= max(start, 0); --j) {
if (!suppressCharProbabilities(
mostCommonKeyWidth, sampledInputSize, sampledLengthCache, i, j,
charProbabilities)) {
break;
}
}
}
// Converting from raw probabilities to log probabilities to calculate spatial distance.
for (int i = start; i < sampledInputSize; ++i) {
for (int j = 0; j < keyCount; ++j) {
hash_map_compat<int, float>::iterator it = (*charProbabilities)[i].find(j);
if (it == (*charProbabilities)[i].end()){
(*nearKeysVector)[i].reset(j);
} else if(it->second < MIN_PROBABILITY) {
// Erases from near keys vector because it has very low probability.
(*nearKeysVector)[i].reset(j);
(*charProbabilities)[i].erase(j);
} else {
it->second = -logf(it->second);
}
}
(*charProbabilities)[i][NOT_AN_INDEX] = -logf((*charProbabilities)[i][NOT_AN_INDEX]);
}
}
// Decreases char probabilities of index0 by checking probabilities of a near point (index1) and
// increases char probabilities of index1 by checking probabilities of index0.
/* static */ bool ProximityInfoStateUtils::suppressCharProbabilities(const int mostCommonKeyWidth,
const int sampledInputSize, const std::vector<int> *const lengthCache,
const int index0, const int index1,
std::vector<hash_map_compat<int, float> > *charProbabilities) {
ASSERT(0 <= index0 && index0 < sampledInputSize);
ASSERT(0 <= index1 && index1 < sampledInputSize);
static const float SUPPRESSION_LENGTH_WEIGHT = 1.5f;
static const float MIN_SUPPRESSION_RATE = 0.1f;
static const float SUPPRESSION_WEIGHT = 0.5f;
static const float SUPPRESSION_WEIGHT_FOR_PROBABILITY_GAIN = 0.1f;
static const float SKIP_PROBABALITY_WEIGHT_FOR_PROBABILITY_GAIN = 0.3f;
const float keyWidthFloat = static_cast<float>(mostCommonKeyWidth);
const float diff = fabsf(static_cast<float>((*lengthCache)[index0] - (*lengthCache)[index1]));
if (diff > keyWidthFloat * SUPPRESSION_LENGTH_WEIGHT) {
return false;
}
const float suppressionRate = MIN_SUPPRESSION_RATE
+ diff / keyWidthFloat / SUPPRESSION_LENGTH_WEIGHT * SUPPRESSION_WEIGHT;
for (hash_map_compat<int, float>::iterator it = (*charProbabilities)[index0].begin();
it != (*charProbabilities)[index0].end(); ++it) {
hash_map_compat<int, float>::iterator it2 = (*charProbabilities)[index1].find(it->first);
if (it2 != (*charProbabilities)[index1].end() && it->second < it2->second) {
const float newProbability = it->second * suppressionRate;
const float suppression = it->second - newProbability;
it->second = newProbability;
// mCharProbabilities[index0][NOT_AN_INDEX] is the probability of skipping this point.
(*charProbabilities)[index0][NOT_AN_INDEX] += suppression;
// Add the probability of the same key nearby index1
const float probabilityGain = min(suppression * SUPPRESSION_WEIGHT_FOR_PROBABILITY_GAIN,
(*charProbabilities)[index1][NOT_AN_INDEX]
* SKIP_PROBABALITY_WEIGHT_FOR_PROBABILITY_GAIN);
it2->second += probabilityGain;
(*charProbabilities)[index1][NOT_AN_INDEX] -= probabilityGain;
}
}
return true;
}
} // namespace latinime } // namespace latinime

View file

@ -17,9 +17,11 @@
#ifndef LATINIME_PROXIMITY_INFO_STATE_UTILS_H #ifndef LATINIME_PROXIMITY_INFO_STATE_UTILS_H
#define LATINIME_PROXIMITY_INFO_STATE_UTILS_H #define LATINIME_PROXIMITY_INFO_STATE_UTILS_H
#include <bitset>
#include <vector> #include <vector>
#include "defines.h" #include "defines.h"
#include "hash_map_compat.h"
namespace latinime { namespace latinime {
class ProximityInfo; class ProximityInfo;
@ -27,6 +29,9 @@ class ProximityInfoParams;
class ProximityInfoStateUtils { class ProximityInfoStateUtils {
public: public:
typedef hash_map_compat<int, float> NearKeysDistanceMap;
typedef std::bitset<MAX_KEY_COUNT_IN_A_KEYBOARD> NearKeycodesSet;
static int updateTouchPoints(const int mostCommonKeyWidth, static int updateTouchPoints(const int mostCommonKeyWidth,
const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const ProximityInfo *const proximityInfo, const int maxPointToKeyLength,
const int *const inputProximities, const int *const inputProximities,
@ -57,12 +62,26 @@ class ProximityInfoStateUtils {
std::vector<int> *beelineSpeedPercentiles); std::vector<int> *beelineSpeedPercentiles);
static float getDirection(const std::vector<int> *const sampledInputXs, static float getDirection(const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs, const int index0, const int index1); const std::vector<int> *const sampledInputYs, const int index0, const int index1);
static void updateAlignPointProbabilities(
const float maxPointToKeyLength, const int mostCommonKeyWidth, const int keyCount,
const int start, const int sampledInputSize,
const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs,
const std::vector<float> *const sampledSpeedRates,
const std::vector<int> *const sampledLengthCache,
const std::vector<float> *const distanceCache_G,
std::vector<NearKeycodesSet> *nearKeysVector,
std::vector<hash_map_compat<int, float> > *charProbabilities);
static float getPointToKeyByIdLength(const float maxPointToKeyLength,
const std::vector<float> *const distanceCache_G, const int keyCount,
const int inputIndex, const int keyId, const float scale);
static float getPointToKeyByIdLength(const float maxPointToKeyLength,
const std::vector<float> *const distanceCache_G, const int keyCount,
const int inputIndex, const int keyId);
private: private:
DISALLOW_IMPLICIT_CONSTRUCTORS(ProximityInfoStateUtils); DISALLOW_IMPLICIT_CONSTRUCTORS(ProximityInfoStateUtils);
typedef hash_map_compat<int, float> NearKeysDistanceMap;
static float updateNearKeysDistances(const ProximityInfo *const proximityInfo, static float updateNearKeysDistances(const ProximityInfo *const proximityInfo,
const float maxPointToKeyLength, const int x, const int y, const float maxPointToKeyLength, const int x, const int y,
NearKeysDistanceMap *const currentNearKeysDistances); NearKeysDistanceMap *const currentNearKeysDistances);
@ -91,6 +110,17 @@ class ProximityInfoStateUtils {
const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs, const std::vector<int> *const sampledInputYs,
const std::vector<int> *const inputIndice); const std::vector<int> *const inputIndice);
static float getPointAngle(
const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs, const int index);
static float getPointsAngle(
const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs,
const int index0, const int index1, const int index2);
static bool suppressCharProbabilities(const int mostCommonKeyWidth,
const int sampledInputSize, const std::vector<int> *const lengthCache,
const int index0, const int index1,
std::vector<hash_map_compat<int, float> > *charProbabilities);
}; };
} // namespace latinime } // namespace latinime
#endif // LATINIME_PROXIMITY_INFO_STATE_UTILS_H #endif // LATINIME_PROXIMITY_INFO_STATE_UTILS_H