Improve skip correction algorithm
Change-Id: Ife45e3886137d60a4e903d4c6f7a9ef20c7e705a
main
parent
c359d75ca7
commit
635f68e822
|
@ -49,12 +49,11 @@ void Correction::initCorrection(const ProximityInfo *pi, const int inputLength,
|
||||||
mInputLength = inputLength;
|
mInputLength = inputLength;
|
||||||
mMaxDepth = maxDepth;
|
mMaxDepth = maxDepth;
|
||||||
mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
|
mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
|
||||||
mSkippedOutputIndex = -1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Correction::initCorrectionState(
|
void Correction::initCorrectionState(
|
||||||
const int rootPos, const int childCount, const bool traverseAll) {
|
const int rootPos, const int childCount, const bool traverseAll) {
|
||||||
mCorrectionStates[0].init(rootPos, childCount, traverseAll);
|
latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Correction::setCorrectionParams(const int skipPos, const int excessivePos,
|
void Correction::setCorrectionParams(const int skipPos, const int excessivePos,
|
||||||
|
@ -88,6 +87,12 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen
|
||||||
if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) {
|
if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Remove this
|
||||||
|
if (mSkipPos >= 0 && mSkippedCount <= 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
*word = mWord;
|
*word = mWord;
|
||||||
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
|
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
|
||||||
: (mInputLength == inputIndex + 1);
|
: (mInputLength == inputIndex + 1);
|
||||||
|
@ -103,8 +108,11 @@ bool Correction::initProcessState(const int outputIndex) {
|
||||||
--(mCorrectionStates[outputIndex].mChildCount);
|
--(mCorrectionStates[outputIndex].mChildCount);
|
||||||
mMatchedCharCount = mCorrectionStates[outputIndex].mMatchedCount;
|
mMatchedCharCount = mCorrectionStates[outputIndex].mMatchedCount;
|
||||||
mInputIndex = mCorrectionStates[outputIndex].mInputIndex;
|
mInputIndex = mCorrectionStates[outputIndex].mInputIndex;
|
||||||
mTraverseAllNodes = mCorrectionStates[outputIndex].mTraverseAll;
|
mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes;
|
||||||
mDiffs = mCorrectionStates[outputIndex].mDiffs;
|
mDiffs = mCorrectionStates[outputIndex].mDiffs;
|
||||||
|
mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount;
|
||||||
|
mSkipping = false;
|
||||||
|
mMatching = false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,8 +139,8 @@ int Correction::getInputIndex() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove
|
// TODO: remove
|
||||||
bool Correction::needsToTraverseAll() {
|
bool Correction::needsToTraverseAllNodes() {
|
||||||
return mTraverseAllNodes;
|
return mNeedsToTraverseAllNodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Correction::incrementInputIndex() {
|
void Correction::incrementInputIndex() {
|
||||||
|
@ -146,12 +154,15 @@ void Correction::incrementOutputIndex() {
|
||||||
mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos;
|
mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos;
|
||||||
mCorrectionStates[mOutputIndex].mMatchedCount = mMatchedCharCount;
|
mCorrectionStates[mOutputIndex].mMatchedCount = mMatchedCharCount;
|
||||||
mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex;
|
mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex;
|
||||||
mCorrectionStates[mOutputIndex].mTraverseAll = mTraverseAllNodes;
|
mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;
|
||||||
mCorrectionStates[mOutputIndex].mDiffs = mDiffs;
|
mCorrectionStates[mOutputIndex].mDiffs = mDiffs;
|
||||||
|
mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount;
|
||||||
|
mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
|
||||||
|
mCorrectionStates[mOutputIndex].mMatching = mMatching;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Correction::startTraverseAll() {
|
void Correction::startToTraverseAllNodes() {
|
||||||
mTraverseAllNodes = true;
|
mNeedsToTraverseAllNodes = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Correction::needsToPrune() const {
|
bool Correction::needsToPrune() const {
|
||||||
|
@ -162,7 +173,7 @@ bool Correction::needsToPrune() const {
|
||||||
Correction::CorrectionType Correction::processSkipChar(
|
Correction::CorrectionType Correction::processSkipChar(
|
||||||
const int32_t c, const bool isTerminal) {
|
const int32_t c, const bool isTerminal) {
|
||||||
mWord[mOutputIndex] = c;
|
mWord[mOutputIndex] = c;
|
||||||
if (needsToTraverseAll() && isTerminal) {
|
if (needsToTraverseAllNodes() && isTerminal) {
|
||||||
mTerminalInputIndex = mInputIndex;
|
mTerminalInputIndex = mInputIndex;
|
||||||
mTerminalOutputIndex = mOutputIndex;
|
mTerminalOutputIndex = mOutputIndex;
|
||||||
incrementOutputIndex();
|
incrementOutputIndex();
|
||||||
|
@ -185,9 +196,10 @@ Correction::CorrectionType Correction::processCharAndCalcState(
|
||||||
bool skip = false;
|
bool skip = false;
|
||||||
if (mSkipPos >= 0) {
|
if (mSkipPos >= 0) {
|
||||||
skip = mSkipPos == mOutputIndex;
|
skip = mSkipPos == mOutputIndex;
|
||||||
|
mSkipping = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mTraverseAllNodes || isQuote(c)) {
|
if (mNeedsToTraverseAllNodes || isQuote(c)) {
|
||||||
return processSkipChar(c, isTerminal);
|
return processSkipChar(c, isTerminal);
|
||||||
} else {
|
} else {
|
||||||
int inputIndexForProximity = mInputIndex;
|
int inputIndexForProximity = mInputIndex;
|
||||||
|
@ -210,25 +222,23 @@ Correction::CorrectionType Correction::processCharAndCalcState(
|
||||||
if (unrelated) {
|
if (unrelated) {
|
||||||
if (skip) {
|
if (skip) {
|
||||||
// Skip this letter and continue deeper
|
// Skip this letter and continue deeper
|
||||||
mSkippedOutputIndex = mOutputIndex;
|
++mSkippedCount;
|
||||||
return processSkipChar(c, isTerminal);
|
return processSkipChar(c, isTerminal);
|
||||||
} else {
|
} else {
|
||||||
return UNRELATED;
|
return UNRELATED;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// No need to skip. Finish traversing and increment skipPos.
|
// TODO: remove after allowing combination errors
|
||||||
// TODO: Remove this?
|
|
||||||
if (skip) {
|
if (skip) {
|
||||||
mWord[mOutputIndex] = c;
|
return UNRELATED;
|
||||||
incrementOutputIndex();
|
|
||||||
return TRAVERSE_ALL_NOT_ON_TERMINAL;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mWord[mOutputIndex] = c;
|
mWord[mOutputIndex] = c;
|
||||||
// If inputIndex is greater than mInputLength, that means there is no
|
// If inputIndex is greater than mInputLength, that means there is no
|
||||||
// proximity chars. So, we don't need to check proximity.
|
// proximity chars. So, we don't need to check proximity.
|
||||||
if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
|
if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
|
||||||
|
mMatching = true;
|
||||||
charMatched();
|
charMatched();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,7 +257,7 @@ Correction::CorrectionType Correction::processCharAndCalcState(
|
||||||
}
|
}
|
||||||
// Start traversing all nodes after the index exceeds the user typed length
|
// Start traversing all nodes after the index exceeds the user typed length
|
||||||
if (isSameAsUserTypedLength) {
|
if (isSameAsUserTypedLength) {
|
||||||
startTraverseAll();
|
startToTraverseAllNodes();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, we are ready to go to the next character, the next "virtual node".
|
// Finally, we are ready to go to the next character, the next "virtual node".
|
||||||
|
@ -317,6 +327,7 @@ inline static void multiplyRate(const int rate, int *freq) {
|
||||||
// RankingAlgorithm //
|
// RankingAlgorithm //
|
||||||
//////////////////////
|
//////////////////////
|
||||||
|
|
||||||
|
/* static */
|
||||||
int Correction::RankingAlgorithm::calculateFinalFreq(
|
int Correction::RankingAlgorithm::calculateFinalFreq(
|
||||||
const int inputIndex, const int outputIndex,
|
const int inputIndex, const int outputIndex,
|
||||||
const int matchCount, const int freq, const bool sameLength,
|
const int matchCount, const int freq, const bool sameLength,
|
||||||
|
@ -329,6 +340,8 @@ int Correction::RankingAlgorithm::calculateFinalFreq(
|
||||||
const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER;
|
const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER;
|
||||||
const ProximityInfo *proximityInfo = correction->mProximityInfo;
|
const ProximityInfo *proximityInfo = correction->mProximityInfo;
|
||||||
const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
|
const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
|
||||||
|
const unsigned short* word = correction->mWord;
|
||||||
|
const int skippedCount = correction->mSkippedCount;
|
||||||
|
|
||||||
// TODO: Demote by edit distance
|
// TODO: Demote by edit distance
|
||||||
int finalFreq = freq * matchWeight;
|
int finalFreq = freq * matchWeight;
|
||||||
|
@ -382,9 +395,30 @@ int Correction::RankingAlgorithm::calculateFinalFreq(
|
||||||
LOGI("calc: %d, %d", outputIndex, sameLength);
|
LOGI("calc: %d, %d", outputIndex, sameLength);
|
||||||
}
|
}
|
||||||
if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq);
|
if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq);
|
||||||
|
|
||||||
|
// TODO: check excessive count and transposed count
|
||||||
|
/*
|
||||||
|
If the last character of the user input word is the same as the next character
|
||||||
|
of the output word, and also all of characters of the user input are matched
|
||||||
|
to the output word, we'll promote that word a bit because
|
||||||
|
that word can be considered the combination of skipped and matched characters.
|
||||||
|
This means that the 'sm' pattern wins over the 'ma' pattern.
|
||||||
|
e.g.)
|
||||||
|
shel -> shell [mmmma] or [mmmsm]
|
||||||
|
hel -> hello [mmmaa] or [mmsma]
|
||||||
|
m ... matching
|
||||||
|
s ... skipping
|
||||||
|
a ... traversing all
|
||||||
|
*/
|
||||||
|
if (matchCount == inputLength && matchCount >= 2 && skippedCount == 0
|
||||||
|
&& word[matchCount] == word[matchCount - 1]) {
|
||||||
|
multiplyRate(WORDS_WITH_MATCH_SKIP_PROMOTION_RATE, &finalFreq);
|
||||||
|
}
|
||||||
|
|
||||||
return finalFreq;
|
return finalFreq;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
int Correction::RankingAlgorithm::calcFreqForSplitTwoWords(
|
int Correction::RankingAlgorithm::calcFreqForSplitTwoWords(
|
||||||
const int firstFreq, const int secondFreq, const Correction* correction) {
|
const int firstFreq, const int secondFreq, const Correction* correction) {
|
||||||
const int spaceProximityPos = correction->mSpaceProximityPos;
|
const int spaceProximityPos = correction->mSpaceProximityPos;
|
||||||
|
|
|
@ -52,7 +52,6 @@ public:
|
||||||
bool *traverseAllNodes, int *diffs);
|
bool *traverseAllNodes, int *diffs);
|
||||||
int getOutputIndex();
|
int getOutputIndex();
|
||||||
int getInputIndex();
|
int getInputIndex();
|
||||||
bool needsToTraverseAll();
|
|
||||||
|
|
||||||
virtual ~Correction();
|
virtual ~Correction();
|
||||||
int getSpaceProximityPos() const {
|
int getSpaceProximityPos() const {
|
||||||
|
@ -101,45 +100,46 @@ public:
|
||||||
return mCorrectionStates[index].mParentIndex;
|
return mCorrectionStates[index].mParentIndex;
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
void charMatched();
|
inline void charMatched();
|
||||||
void incrementInputIndex();
|
inline void incrementInputIndex();
|
||||||
void incrementOutputIndex();
|
inline void incrementOutputIndex();
|
||||||
void startTraverseAll();
|
inline bool needsToTraverseAllNodes();
|
||||||
|
inline void startToTraverseAllNodes();
|
||||||
|
inline bool isQuote(const unsigned short c);
|
||||||
|
inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal);
|
||||||
|
|
||||||
// TODO: remove
|
// TODO: remove
|
||||||
|
inline void incrementDiffs() {
|
||||||
void incrementDiffs() {
|
|
||||||
++mDiffs;
|
++mDiffs;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int TYPED_LETTER_MULTIPLIER;
|
const int TYPED_LETTER_MULTIPLIER;
|
||||||
const int FULL_WORD_MULTIPLIER;
|
const int FULL_WORD_MULTIPLIER;
|
||||||
|
|
||||||
const ProximityInfo *mProximityInfo;
|
const ProximityInfo *mProximityInfo;
|
||||||
|
|
||||||
int mMaxEditDistance;
|
int mMaxEditDistance;
|
||||||
int mMaxDepth;
|
int mMaxDepth;
|
||||||
int mInputLength;
|
int mInputLength;
|
||||||
int mSkipPos;
|
int mSkipPos;
|
||||||
int mSkippedOutputIndex;
|
|
||||||
int mExcessivePos;
|
int mExcessivePos;
|
||||||
int mTransposedPos;
|
int mTransposedPos;
|
||||||
int mSpaceProximityPos;
|
int mSpaceProximityPos;
|
||||||
int mMissingSpacePos;
|
int mMissingSpacePos;
|
||||||
|
|
||||||
int mMatchedCharCount;
|
|
||||||
int mInputIndex;
|
|
||||||
int mOutputIndex;
|
|
||||||
int mTerminalInputIndex;
|
int mTerminalInputIndex;
|
||||||
int mTerminalOutputIndex;
|
int mTerminalOutputIndex;
|
||||||
int mDiffs;
|
|
||||||
bool mTraverseAllNodes;
|
|
||||||
unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
|
unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
|
||||||
|
|
||||||
CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];
|
CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];
|
||||||
|
|
||||||
inline bool isQuote(const unsigned short c);
|
// The following member variables are being used as cache values of the correction state.
|
||||||
inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal);
|
int mOutputIndex;
|
||||||
|
int mInputIndex;
|
||||||
|
int mDiffs;
|
||||||
|
int mMatchedCharCount;
|
||||||
|
int mSkippedCount;
|
||||||
|
bool mNeedsToTraverseAllNodes;
|
||||||
|
bool mMatching;
|
||||||
|
bool mSkipping;
|
||||||
|
|
||||||
class RankingAlgorithm {
|
class RankingAlgorithm {
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -23,32 +23,33 @@
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
class CorrectionState {
|
struct CorrectionState {
|
||||||
public:
|
|
||||||
int mParentIndex;
|
int mParentIndex;
|
||||||
int mMatchedCount;
|
|
||||||
int mChildCount;
|
|
||||||
int mInputIndex;
|
|
||||||
int mDiffs;
|
|
||||||
int mSiblingPos;
|
int mSiblingPos;
|
||||||
bool mTraverseAll;
|
uint16_t mChildCount;
|
||||||
|
uint8_t mInputIndex;
|
||||||
|
uint8_t mDiffs;
|
||||||
|
uint8_t mMatchedCount;
|
||||||
|
uint8_t mSkippedCount;
|
||||||
|
bool mMatching;
|
||||||
|
bool mSkipping;
|
||||||
|
bool mNeedsToTraverseAllNodes;
|
||||||
|
|
||||||
inline void init(const int rootPos, const int childCount, const bool traverseAll) {
|
|
||||||
set(-1, 0, childCount, 0, 0, rootPos, traverseAll);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
inline void set(const int parentIndex, const int matchedCount, const int childCount,
|
|
||||||
const int inputIndex, const int diffs, const int siblingPos,
|
|
||||||
const bool traverseAll) {
|
|
||||||
mParentIndex = parentIndex;
|
|
||||||
mMatchedCount = matchedCount;
|
|
||||||
mChildCount = childCount;
|
|
||||||
mInputIndex = inputIndex;
|
|
||||||
mDiffs = diffs;
|
|
||||||
mSiblingPos = siblingPos;
|
|
||||||
mTraverseAll = traverseAll;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline static void initCorrectionState(CorrectionState *state, const int rootPos,
|
||||||
|
const uint16_t childCount, const bool traverseAll) {
|
||||||
|
state->mParentIndex = -1;
|
||||||
|
state->mChildCount = childCount;
|
||||||
|
state->mInputIndex = 0;
|
||||||
|
state->mDiffs = 0;
|
||||||
|
state->mSiblingPos = rootPos;
|
||||||
|
state->mMatchedCount = 0;
|
||||||
|
state->mSkippedCount = 0;
|
||||||
|
state->mMatching = false;
|
||||||
|
state->mSkipping = false;
|
||||||
|
state->mNeedsToTraverseAllNodes = traverseAll;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
#endif // LATINIME_CORRECTION_STATE_H
|
#endif // LATINIME_CORRECTION_STATE_H
|
||||||
|
|
|
@ -160,6 +160,7 @@ static void prof_out(void) {
|
||||||
#define WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE 60
|
#define WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE 60
|
||||||
#define FULL_MATCHED_WORDS_PROMOTION_RATE 120
|
#define FULL_MATCHED_WORDS_PROMOTION_RATE 120
|
||||||
#define WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE 90
|
#define WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE 90
|
||||||
|
#define WORDS_WITH_MATCH_SKIP_PROMOTION_RATE 105
|
||||||
|
|
||||||
// This should be greater than or equal to MAX_WORD_LENGTH defined in BinaryDictionary.java
|
// This should be greater than or equal to MAX_WORD_LENGTH defined in BinaryDictionary.java
|
||||||
// This is only used for the size of array. Not to be used in c functions.
|
// This is only used for the size of array. Not to be used in c functions.
|
||||||
|
|
|
@ -46,6 +46,7 @@ public:
|
||||||
ProximityType getMatchedProximityId(
|
ProximityType getMatchedProximityId(
|
||||||
const int index, const unsigned short c, const bool checkProximityChars) const;
|
const int index, const unsigned short c, const bool checkProximityChars) const;
|
||||||
bool sameAsTyped(const unsigned short *word, int length) const;
|
bool sameAsTyped(const unsigned short *word, int length) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int getStartIndexFromCoordinates(const int x, const int y) const;
|
int getStartIndexFromCoordinates(const int x, const int y) const;
|
||||||
const int MAX_PROXIMITY_CHARS_SIZE;
|
const int MAX_PROXIMITY_CHARS_SIZE;
|
||||||
|
|
Loading…
Reference in New Issue