Merge "clean up seach key vector"

This commit is contained in:
Satoshi Kataoka 2013-02-14 08:47:53 +00:00 committed by Android (Google) Code Review
commit f9097056f2
6 changed files with 57 additions and 62 deletions

View file

@ -88,7 +88,7 @@ const float ProximityInfoParams::SKIP_PROBABALITY_WEIGHT_FOR_PROBABILITY_GAIN =
// Used by ProximityInfoStateUtils::getMostProbableString() // Used by ProximityInfoStateUtils::getMostProbableString()
const float ProximityInfoParams::DEMOTION_LOG_PROBABILITY = 0.3f; const float ProximityInfoParams::DEMOTION_LOG_PROBABILITY = 0.3f;
// Used by ProximityInfoStateUtils::updateSampledSearchKeysVector() // Used by ProximityInfoStateUtils::updateSampledSearchKeySets()
// TODO: Investigate if this is required // TODO: Investigate if this is required
const float ProximityInfoParams::SEARCH_KEY_RADIUS_RATIO = 0.95f; const float ProximityInfoParams::SEARCH_KEY_RADIUS_RATIO = 0.95f;

View file

@ -90,7 +90,7 @@ class ProximityInfoParams {
// Used by ProximityInfoStateUtils::getMostProbableString() // Used by ProximityInfoStateUtils::getMostProbableString()
static const float DEMOTION_LOG_PROBABILITY; static const float DEMOTION_LOG_PROBABILITY;
// Used by ProximityInfoStateUtils::updateSampledSearchKeysVector() // Used by ProximityInfoStateUtils::updateSampledSearchKeySets()
static const float SEARCH_KEY_RADIUS_RATIO; static const float SEARCH_KEY_RADIUS_RATIO;
// Used by ProximityInfoStateUtils::calculateBeelineSpeedRate() // Used by ProximityInfoStateUtils::calculateBeelineSpeedRate()

View file

@ -16,6 +16,7 @@
#include <cstring> // for memset() and memcpy() #include <cstring> // for memset() and memcpy()
#include <sstream> // for debug prints #include <sstream> // for debug prints
#include <vector>
#define LOG_TAG "LatinIME: proximity_info_state.cpp" #define LOG_TAG "LatinIME: proximity_info_state.cpp"
@ -75,8 +76,8 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
mSampledInputIndice.clear(); mSampledInputIndice.clear();
mSampledLengthCache.clear(); mSampledLengthCache.clear();
mSampledDistanceCache_G.clear(); mSampledDistanceCache_G.clear();
mSampledNearKeysVector.clear(); mSampledNearKeySets.clear();
mSampledSearchKeysVector.clear(); mSampledSearchKeySets.clear();
mSpeedRates.clear(); mSpeedRates.clear();
mBeelineSpeedPercentiles.clear(); mBeelineSpeedPercentiles.clear();
mCharProbabilities.clear(); mCharProbabilities.clear();
@ -109,7 +110,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
if (mSampledInputSize > 0) { if (mSampledInputSize > 0) {
ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize, ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize,
lastSavedInputSize, &mSampledInputXs, &mSampledInputYs, &mSampledNearKeysVector, lastSavedInputSize, &mSampledInputXs, &mSampledInputYs, &mSampledNearKeySets,
&mSampledDistanceCache_G); &mSampledDistanceCache_G);
if (isGeometric) { if (isGeometric) {
// updates probabilities of skipping or mapping each key for all points. // updates probabilities of skipping or mapping each key for all points.
@ -117,10 +118,11 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(), mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(),
mProximityInfo->getKeyCount(), lastSavedInputSize, mSampledInputSize, mProximityInfo->getKeyCount(), lastSavedInputSize, mSampledInputSize,
&mSampledInputXs, &mSampledInputYs, &mSpeedRates, &mSampledLengthCache, &mSampledInputXs, &mSampledInputYs, &mSpeedRates, &mSampledLengthCache,
&mSampledDistanceCache_G, &mSampledNearKeysVector, &mCharProbabilities); &mSampledDistanceCache_G, &mSampledNearKeySets, &mCharProbabilities);
ProximityInfoStateUtils::updateSampledSearchKeysVector(mProximityInfo, ProximityInfoStateUtils::updateSampledSearchKeySets(mProximityInfo,
mSampledInputSize, lastSavedInputSize, &mSampledLengthCache, mSampledInputSize, lastSavedInputSize, &mSampledLengthCache,
&mSampledNearKeysVector, &mSampledSearchKeysVector); &mSampledNearKeySets, &mSampledSearchKeySets,
&mSampledSearchKeyVectors);
mMostProbableStringProbability = ProximityInfoStateUtils::getMostProbableString( mMostProbableStringProbability = ProximityInfoStateUtils::getMostProbableString(
mProximityInfo, mSampledInputSize, &mCharProbabilities, mMostProbableString); mProximityInfo, mSampledInputSize, &mCharProbabilities, mMostProbableString);
@ -245,36 +247,9 @@ ProximityType ProximityInfoState::getMatchedProximityId(const int index, const i
return UNRELATED_CHAR; return UNRELATED_CHAR;
} }
// Puts possible characters into filter and returns new filter size.
int ProximityInfoState::getAllPossibleChars(
const size_t index, int *const filter, const int filterSize) const {
if (index >= mSampledInputXs.size()) {
return filterSize;
}
int newFilterSize = filterSize;
const int keyCount = mProximityInfo->getKeyCount();
for (int j = 0; j < keyCount; ++j) {
if (mSampledSearchKeysVector[index].test(j)) {
const int keyCodePoint = mProximityInfo->getCodePointOf(j);
bool insert = true;
// TODO: Avoid linear search
for (int k = 0; k < filterSize; ++k) {
if (filter[k] == keyCodePoint) {
insert = false;
break;
}
}
if (insert) {
filter[newFilterSize++] = keyCodePoint;
}
}
}
return newFilterSize;
}
bool ProximityInfoState::isKeyInSerchKeysAfterIndex(const int index, const int keyId) const { bool ProximityInfoState::isKeyInSerchKeysAfterIndex(const int index, const int keyId) const {
ASSERT(keyId >= 0 && index >= 0 && index < mSampledInputSize); ASSERT(keyId >= 0 && index >= 0 && index < mSampledInputSize);
return mSampledSearchKeysVector[index].test(keyId); return mSampledSearchKeySets[index].test(keyId);
} }
float ProximityInfoState::getDirection(const int index0, const int index1) const { float ProximityInfoState::getDirection(const int index0, const int index1) const {

View file

@ -50,7 +50,7 @@ class ProximityInfoState {
mIsContinuationPossible(false), mSampledInputXs(), mSampledInputYs(), mSampledTimes(), mIsContinuationPossible(false), mSampledInputXs(), mSampledInputYs(), mSampledTimes(),
mSampledInputIndice(), mSampledLengthCache(), mBeelineSpeedPercentiles(), mSampledInputIndice(), mSampledLengthCache(), mBeelineSpeedPercentiles(),
mSampledDistanceCache_G(), mSpeedRates(), mDirections(), mCharProbabilities(), mSampledDistanceCache_G(), mSpeedRates(), mDirections(), mCharProbabilities(),
mSampledNearKeysVector(), mSampledSearchKeysVector(), mSampledNearKeySets(), mSampledSearchKeySets(), mSampledSearchKeyVectors(),
mTouchPositionCorrectionEnabled(false), mSampledInputSize(0), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0),
mMostProbableStringProbability(0.0f) { mMostProbableStringProbability(0.0f) {
memset(mInputProximities, 0, sizeof(mInputProximities)); memset(mInputProximities, 0, sizeof(mInputProximities));
@ -155,7 +155,9 @@ class ProximityInfoState {
ProximityType getMatchedProximityId(const int index, const int c, ProximityType getMatchedProximityId(const int index, const int c,
const bool checkProximityChars, int *proximityIndex = 0) const; const bool checkProximityChars, int *proximityIndex = 0) const;
int getAllPossibleChars(const size_t startIndex, int *const filter, const int filterSize) const; const std::vector<int> *getSearchKeyVector(const int index) const {
return &mSampledSearchKeyVectors[index];
}
float getSpeedRate(const int index) const { float getSpeedRate(const int index) const {
return mSpeedRates[index]; return mSpeedRates[index];
@ -236,13 +238,14 @@ class ProximityInfoState {
std::vector<hash_map_compat<int, float> > mCharProbabilities; std::vector<hash_map_compat<int, float> > mCharProbabilities;
// The vector for the key code set which holds nearby keys for each sampled input point // The vector for the key code set which holds nearby keys for each sampled input point
// 1. Used to calculate the probability of the key // 1. Used to calculate the probability of the key
// 2. Used to calculate mSampledSearchKeysVector // 2. Used to calculate mSampledSearchKeySets
std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledNearKeysVector; std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledNearKeySets;
// The vector for the key code set which holds nearby keys of some trailing sampled input points // The vector for the key code set which holds nearby keys of some trailing sampled input points
// for each sampled input point. These nearby keys contain the next characters which can be in // for each sampled input point. These nearby keys contain the next characters which can be in
// the dictionary. Specifically, currently we are looking for keys nearby trailing sampled // the dictionary. Specifically, currently we are looking for keys nearby trailing sampled
// inputs including the current input point. // inputs including the current input point.
std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledSearchKeysVector; std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledSearchKeySets;
std::vector<std::vector<int> > mSampledSearchKeyVectors;
bool mTouchPositionCorrectionEnabled; bool mTouchPositionCorrectionEnabled;
int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];

View file

@ -224,13 +224,13 @@ namespace latinime {
const ProximityInfo *const proximityInfo, const int sampledInputSize, const ProximityInfo *const proximityInfo, const int sampledInputSize,
const int lastSavedInputSize, const std::vector<int> *const sampledInputXs, const int lastSavedInputSize, const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs, const std::vector<int> *const sampledInputYs,
std::vector<NearKeycodesSet> *SampledNearKeysVector, std::vector<NearKeycodesSet> *SampledNearKeySets,
std::vector<float> *SampledDistanceCache_G) { std::vector<float> *SampledDistanceCache_G) {
SampledNearKeysVector->resize(sampledInputSize); SampledNearKeySets->resize(sampledInputSize);
const int keyCount = proximityInfo->getKeyCount(); const int keyCount = proximityInfo->getKeyCount();
SampledDistanceCache_G->resize(sampledInputSize * keyCount); SampledDistanceCache_G->resize(sampledInputSize * keyCount);
for (int i = lastSavedInputSize; i < sampledInputSize; ++i) { for (int i = lastSavedInputSize; i < sampledInputSize; ++i) {
(*SampledNearKeysVector)[i].reset(); (*SampledNearKeySets)[i].reset();
for (int k = 0; k < keyCount; ++k) { for (int k = 0; k < keyCount; ++k) {
const int index = i * keyCount + k; const int index = i * keyCount + k;
const int x = (*sampledInputXs)[i]; const int x = (*sampledInputXs)[i];
@ -240,7 +240,7 @@ namespace latinime {
(*SampledDistanceCache_G)[index] = normalizedSquaredDistance; (*SampledDistanceCache_G)[index] = normalizedSquaredDistance;
if (normalizedSquaredDistance if (normalizedSquaredDistance
< ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { < ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) {
(*SampledNearKeysVector)[i][k] = true; (*SampledNearKeySets)[i][k] = true;
} }
} }
} }
@ -664,7 +664,7 @@ namespace latinime {
const std::vector<float> *const sampledSpeedRates, const std::vector<float> *const sampledSpeedRates,
const std::vector<int> *const sampledLengthCache, const std::vector<int> *const sampledLengthCache,
const std::vector<float> *const SampledDistanceCache_G, const std::vector<float> *const SampledDistanceCache_G,
std::vector<NearKeycodesSet> *SampledNearKeysVector, std::vector<NearKeycodesSet> *SampledNearKeySets,
std::vector<hash_map_compat<int, float> > *charProbabilities) { std::vector<hash_map_compat<int, float> > *charProbabilities) {
charProbabilities->resize(sampledInputSize); charProbabilities->resize(sampledInputSize);
// Calculates probabilities of using a point as a correlated point with the character // Calculates probabilities of using a point as a correlated point with the character
@ -680,7 +680,7 @@ namespace latinime {
float nearestKeyDistance = static_cast<float>(MAX_POINT_TO_KEY_LENGTH); float nearestKeyDistance = static_cast<float>(MAX_POINT_TO_KEY_LENGTH);
for (int j = 0; j < keyCount; ++j) { for (int j = 0; j < keyCount; ++j) {
if ((*SampledNearKeysVector)[i].test(j)) { if ((*SampledNearKeySets)[i].test(j)) {
const float distance = getPointToKeyByIdLength( const float distance = getPointToKeyByIdLength(
maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j); maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j);
if (distance < nearestKeyDistance) { if (distance < nearestKeyDistance) {
@ -761,7 +761,7 @@ namespace latinime {
// Summing up probability densities of all near keys. // Summing up probability densities of all near keys.
float sumOfProbabilityDensities = 0.0f; float sumOfProbabilityDensities = 0.0f;
for (int j = 0; j < keyCount; ++j) { for (int j = 0; j < keyCount; ++j) {
if ((*SampledNearKeysVector)[i].test(j)) { if ((*SampledNearKeySets)[i].test(j)) {
float distance = sqrtf(getPointToKeyByIdLength( float distance = sqrtf(getPointToKeyByIdLength(
maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j)); maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j));
if (i == 0 && i != sampledInputSize - 1) { if (i == 0 && i != sampledInputSize - 1) {
@ -801,7 +801,7 @@ namespace latinime {
// Split the probability of an input point to keys that are close to the input point. // Split the probability of an input point to keys that are close to the input point.
for (int j = 0; j < keyCount; ++j) { for (int j = 0; j < keyCount; ++j) {
if ((*SampledNearKeysVector)[i].test(j)) { if ((*SampledNearKeySets)[i].test(j)) {
float distance = sqrtf(getPointToKeyByIdLength( float distance = sqrtf(getPointToKeyByIdLength(
maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j)); maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j));
if (i == 0 && i != sampledInputSize - 1) { if (i == 0 && i != sampledInputSize - 1) {
@ -885,10 +885,10 @@ namespace latinime {
for (int j = 0; j < keyCount; ++j) { for (int j = 0; j < keyCount; ++j) {
hash_map_compat<int, float>::iterator it = (*charProbabilities)[i].find(j); hash_map_compat<int, float>::iterator it = (*charProbabilities)[i].find(j);
if (it == (*charProbabilities)[i].end()){ if (it == (*charProbabilities)[i].end()){
(*SampledNearKeysVector)[i].reset(j); (*SampledNearKeySets)[i].reset(j);
} else if(it->second < ProximityInfoParams::MIN_PROBABILITY) { } else if(it->second < ProximityInfoParams::MIN_PROBABILITY) {
// Erases from near keys vector because it has very low probability. // Erases from near keys vector because it has very low probability.
(*SampledNearKeysVector)[i].reset(j); (*SampledNearKeySets)[i].reset(j);
(*charProbabilities)[i].erase(j); (*charProbabilities)[i].erase(j);
} else { } else {
it->second = -logf(it->second); it->second = -logf(it->second);
@ -898,26 +898,42 @@ namespace latinime {
} }
} }
/* static */ void ProximityInfoStateUtils::updateSampledSearchKeysVector( /* static */ void ProximityInfoStateUtils::updateSampledSearchKeySets(
const ProximityInfo *const proximityInfo, const int sampledInputSize, const ProximityInfo *const proximityInfo, const int sampledInputSize,
const int lastSavedInputSize, const int lastSavedInputSize,
const std::vector<int> *const sampledLengthCache, const std::vector<int> *const sampledLengthCache,
const std::vector<NearKeycodesSet> *const SampledNearKeysVector, const std::vector<NearKeycodesSet> *const SampledNearKeySets,
std::vector<NearKeycodesSet> *sampledSearchKeysVector) { std::vector<NearKeycodesSet> *sampledSearchKeySets,
sampledSearchKeysVector->resize(sampledInputSize); std::vector<std::vector<int> > *sampledSearchKeyVectors) {
sampledSearchKeySets->resize(sampledInputSize);
sampledSearchKeyVectors->resize(sampledInputSize);
const int readForwordLength = static_cast<int>( const int readForwordLength = static_cast<int>(
hypotf(proximityInfo->getKeyboardWidth(), proximityInfo->getKeyboardHeight()) hypotf(proximityInfo->getKeyboardWidth(), proximityInfo->getKeyboardHeight())
* ProximityInfoParams::SEARCH_KEY_RADIUS_RATIO); * ProximityInfoParams::SEARCH_KEY_RADIUS_RATIO);
for (int i = 0; i < sampledInputSize; ++i) { for (int i = 0; i < sampledInputSize; ++i) {
if (i >= lastSavedInputSize) { if (i >= lastSavedInputSize) {
(*sampledSearchKeysVector)[i].reset(); (*sampledSearchKeySets)[i].reset();
} }
for (int j = max(i, lastSavedInputSize); j < sampledInputSize; ++j) { for (int j = max(i, lastSavedInputSize); j < sampledInputSize; ++j) {
// TODO: Investigate if this is required. This may not fail. // TODO: Investigate if this is required. This may not fail.
if ((*sampledLengthCache)[j] - (*sampledLengthCache)[i] >= readForwordLength) { if ((*sampledLengthCache)[j] - (*sampledLengthCache)[i] >= readForwordLength) {
break; break;
} }
(*sampledSearchKeysVector)[i] |= (*SampledNearKeysVector)[j]; (*sampledSearchKeySets)[i] |= (*SampledNearKeySets)[j];
}
}
const int keyCount = proximityInfo->getKeyCount();
for (int i = 0; i < sampledInputSize; ++i) {
std::vector<int> *searchKeyVector = &(*sampledSearchKeyVectors)[i];
searchKeyVector->clear();
for (int j = 0; j < keyCount; ++j) {
if ((*sampledSearchKeySets)[i].test(j)) {
const int keyCodePoint = proximityInfo->getCodePointOf(j);
if (std::find(searchKeyVector->begin(), searchKeyVector->end(), keyCodePoint)
== searchKeyVector->end()) {
searchKeyVector->push_back(keyCodePoint);
}
}
} }
} }
} }

View file

@ -71,13 +71,14 @@ class ProximityInfoStateUtils {
const std::vector<float> *const sampledSpeedRates, const std::vector<float> *const sampledSpeedRates,
const std::vector<int> *const sampledLengthCache, const std::vector<int> *const sampledLengthCache,
const std::vector<float> *const SampledDistanceCache_G, const std::vector<float> *const SampledDistanceCache_G,
std::vector<NearKeycodesSet> *SampledNearKeysVector, std::vector<NearKeycodesSet> *SampledNearKeySets,
std::vector<hash_map_compat<int, float> > *charProbabilities); std::vector<hash_map_compat<int, float> > *charProbabilities);
static void updateSampledSearchKeysVector(const ProximityInfo *const proximityInfo, static void updateSampledSearchKeySets(const ProximityInfo *const proximityInfo,
const int sampledInputSize, const int lastSavedInputSize, const int sampledInputSize, const int lastSavedInputSize,
const std::vector<int> *const sampledLengthCache, const std::vector<int> *const sampledLengthCache,
const std::vector<NearKeycodesSet> *const SampledNearKeysVector, const std::vector<NearKeycodesSet> *const SampledNearKeySets,
std::vector<NearKeycodesSet> *sampledSearchKeysVector); std::vector<NearKeycodesSet> *sampledSearchKeySets,
std::vector<std::vector<int> > *sampledSearchKeyVectors);
static float getPointToKeyByIdLength(const float maxPointToKeyLength, static float getPointToKeyByIdLength(const float maxPointToKeyLength,
const std::vector<float> *const SampledDistanceCache_G, const int keyCount, const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
const int inputIndex, const int keyId, const float scale); const int inputIndex, const int keyId, const float scale);
@ -88,7 +89,7 @@ class ProximityInfoStateUtils {
const int sampledInputSize, const int lastSavedInputSize, const int sampledInputSize, const int lastSavedInputSize,
const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs, const std::vector<int> *const sampledInputYs,
std::vector<NearKeycodesSet> *SampledNearKeysVector, std::vector<NearKeycodesSet> *SampledNearKeySets,
std::vector<float> *SampledDistanceCache_G); std::vector<float> *SampledDistanceCache_G);
static void initPrimaryInputWord(const int inputSize, const int *const inputProximities, static void initPrimaryInputWord(const int inputSize, const int *const inputProximities,
int *primaryInputWord); int *primaryInputWord);