Improve skip correction algorithm

Change-Id: Ife45e3886137d60a4e903d4c6f7a9ef20c7e705a
main
satok 2011-08-10 22:19:33 +09:00
parent c359d75ca7
commit 635f68e822
5 changed files with 94 additions and 57 deletions

View File

@ -49,12 +49,11 @@ void Correction::initCorrection(const ProximityInfo *pi, const int inputLength,
mInputLength = inputLength; mInputLength = inputLength;
mMaxDepth = maxDepth; mMaxDepth = maxDepth;
mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2; mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
mSkippedOutputIndex = -1;
} }
void Correction::initCorrectionState( void Correction::initCorrectionState(
const int rootPos, const int childCount, const bool traverseAll) { const int rootPos, const int childCount, const bool traverseAll) {
mCorrectionStates[0].init(rootPos, childCount, traverseAll); latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll);
} }
void Correction::setCorrectionParams(const int skipPos, const int excessivePos, void Correction::setCorrectionParams(const int skipPos, const int excessivePos,
@ -88,6 +87,12 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen
if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) { if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) {
return -1; return -1;
} }
// TODO: Remove this
if (mSkipPos >= 0 && mSkippedCount <= 0) {
return -1;
}
*word = mWord; *word = mWord;
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2) const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
: (mInputLength == inputIndex + 1); : (mInputLength == inputIndex + 1);
@ -103,8 +108,11 @@ bool Correction::initProcessState(const int outputIndex) {
--(mCorrectionStates[outputIndex].mChildCount); --(mCorrectionStates[outputIndex].mChildCount);
mMatchedCharCount = mCorrectionStates[outputIndex].mMatchedCount; mMatchedCharCount = mCorrectionStates[outputIndex].mMatchedCount;
mInputIndex = mCorrectionStates[outputIndex].mInputIndex; mInputIndex = mCorrectionStates[outputIndex].mInputIndex;
mTraverseAllNodes = mCorrectionStates[outputIndex].mTraverseAll; mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes;
mDiffs = mCorrectionStates[outputIndex].mDiffs; mDiffs = mCorrectionStates[outputIndex].mDiffs;
mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount;
mSkipping = false;
mMatching = false;
return true; return true;
} }
@ -131,8 +139,8 @@ int Correction::getInputIndex() {
} }
// TODO: remove // TODO: remove
bool Correction::needsToTraverseAll() { bool Correction::needsToTraverseAllNodes() {
return mTraverseAllNodes; return mNeedsToTraverseAllNodes;
} }
void Correction::incrementInputIndex() { void Correction::incrementInputIndex() {
@ -146,12 +154,15 @@ void Correction::incrementOutputIndex() {
mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos;
mCorrectionStates[mOutputIndex].mMatchedCount = mMatchedCharCount; mCorrectionStates[mOutputIndex].mMatchedCount = mMatchedCharCount;
mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex;
mCorrectionStates[mOutputIndex].mTraverseAll = mTraverseAllNodes; mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;
mCorrectionStates[mOutputIndex].mDiffs = mDiffs; mCorrectionStates[mOutputIndex].mDiffs = mDiffs;
mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount;
mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
mCorrectionStates[mOutputIndex].mMatching = mMatching;
} }
void Correction::startTraverseAll() { void Correction::startToTraverseAllNodes() {
mTraverseAllNodes = true; mNeedsToTraverseAllNodes = true;
} }
bool Correction::needsToPrune() const { bool Correction::needsToPrune() const {
@ -162,7 +173,7 @@ bool Correction::needsToPrune() const {
Correction::CorrectionType Correction::processSkipChar( Correction::CorrectionType Correction::processSkipChar(
const int32_t c, const bool isTerminal) { const int32_t c, const bool isTerminal) {
mWord[mOutputIndex] = c; mWord[mOutputIndex] = c;
if (needsToTraverseAll() && isTerminal) { if (needsToTraverseAllNodes() && isTerminal) {
mTerminalInputIndex = mInputIndex; mTerminalInputIndex = mInputIndex;
mTerminalOutputIndex = mOutputIndex; mTerminalOutputIndex = mOutputIndex;
incrementOutputIndex(); incrementOutputIndex();
@ -185,9 +196,10 @@ Correction::CorrectionType Correction::processCharAndCalcState(
bool skip = false; bool skip = false;
if (mSkipPos >= 0) { if (mSkipPos >= 0) {
skip = mSkipPos == mOutputIndex; skip = mSkipPos == mOutputIndex;
mSkipping = true;
} }
if (mTraverseAllNodes || isQuote(c)) { if (mNeedsToTraverseAllNodes || isQuote(c)) {
return processSkipChar(c, isTerminal); return processSkipChar(c, isTerminal);
} else { } else {
int inputIndexForProximity = mInputIndex; int inputIndexForProximity = mInputIndex;
@ -210,25 +222,23 @@ Correction::CorrectionType Correction::processCharAndCalcState(
if (unrelated) { if (unrelated) {
if (skip) { if (skip) {
// Skip this letter and continue deeper // Skip this letter and continue deeper
mSkippedOutputIndex = mOutputIndex; ++mSkippedCount;
return processSkipChar(c, isTerminal); return processSkipChar(c, isTerminal);
} else { } else {
return UNRELATED; return UNRELATED;
} }
} }
// No need to skip. Finish traversing and increment skipPos. // TODO: remove after allowing combination errors
// TODO: Remove this?
if (skip) { if (skip) {
mWord[mOutputIndex] = c; return UNRELATED;
incrementOutputIndex();
return TRAVERSE_ALL_NOT_ON_TERMINAL;
} }
mWord[mOutputIndex] = c; mWord[mOutputIndex] = c;
// If inputIndex is greater than mInputLength, that means there is no // If inputIndex is greater than mInputLength, that means there is no
// proximity chars. So, we don't need to check proximity. // proximity chars. So, we don't need to check proximity.
if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
mMatching = true;
charMatched(); charMatched();
} }
@ -247,7 +257,7 @@ Correction::CorrectionType Correction::processCharAndCalcState(
} }
// Start traversing all nodes after the index exceeds the user typed length // Start traversing all nodes after the index exceeds the user typed length
if (isSameAsUserTypedLength) { if (isSameAsUserTypedLength) {
startTraverseAll(); startToTraverseAllNodes();
} }
// Finally, we are ready to go to the next character, the next "virtual node". // Finally, we are ready to go to the next character, the next "virtual node".
@ -317,6 +327,7 @@ inline static void multiplyRate(const int rate, int *freq) {
// RankingAlgorithm // // RankingAlgorithm //
////////////////////// //////////////////////
/* static */
int Correction::RankingAlgorithm::calculateFinalFreq( int Correction::RankingAlgorithm::calculateFinalFreq(
const int inputIndex, const int outputIndex, const int inputIndex, const int outputIndex,
const int matchCount, const int freq, const bool sameLength, const int matchCount, const int freq, const bool sameLength,
@ -329,6 +340,8 @@ int Correction::RankingAlgorithm::calculateFinalFreq(
const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER; const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER;
const ProximityInfo *proximityInfo = correction->mProximityInfo; const ProximityInfo *proximityInfo = correction->mProximityInfo;
const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
const unsigned short* word = correction->mWord;
const int skippedCount = correction->mSkippedCount;
// TODO: Demote by edit distance // TODO: Demote by edit distance
int finalFreq = freq * matchWeight; int finalFreq = freq * matchWeight;
@ -382,9 +395,30 @@ int Correction::RankingAlgorithm::calculateFinalFreq(
LOGI("calc: %d, %d", outputIndex, sameLength); LOGI("calc: %d, %d", outputIndex, sameLength);
} }
if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq); if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq);
// TODO: check excessive count and transposed count
/*
If the last character of the user input word is the same as the next character
of the output word, and also all of characters of the user input are matched
to the output word, we'll promote that word a bit because
that word can be considered the combination of skipped and matched characters.
This means that the 'sm' pattern wins over the 'ma' pattern.
e.g.)
shel -> shell [mmmma] or [mmmsm]
hel -> hello [mmmaa] or [mmsma]
m ... matching
s ... skipping
a ... traversing all
*/
if (matchCount == inputLength && matchCount >= 2 && skippedCount == 0
&& word[matchCount] == word[matchCount - 1]) {
multiplyRate(WORDS_WITH_MATCH_SKIP_PROMOTION_RATE, &finalFreq);
}
return finalFreq; return finalFreq;
} }
/* static */
int Correction::RankingAlgorithm::calcFreqForSplitTwoWords( int Correction::RankingAlgorithm::calcFreqForSplitTwoWords(
const int firstFreq, const int secondFreq, const Correction* correction) { const int firstFreq, const int secondFreq, const Correction* correction) {
const int spaceProximityPos = correction->mSpaceProximityPos; const int spaceProximityPos = correction->mSpaceProximityPos;

View File

@ -52,7 +52,6 @@ public:
bool *traverseAllNodes, int *diffs); bool *traverseAllNodes, int *diffs);
int getOutputIndex(); int getOutputIndex();
int getInputIndex(); int getInputIndex();
bool needsToTraverseAll();
virtual ~Correction(); virtual ~Correction();
int getSpaceProximityPos() const { int getSpaceProximityPos() const {
@ -101,45 +100,46 @@ public:
return mCorrectionStates[index].mParentIndex; return mCorrectionStates[index].mParentIndex;
} }
private: private:
void charMatched(); inline void charMatched();
void incrementInputIndex(); inline void incrementInputIndex();
void incrementOutputIndex(); inline void incrementOutputIndex();
void startTraverseAll(); inline bool needsToTraverseAllNodes();
inline void startToTraverseAllNodes();
inline bool isQuote(const unsigned short c);
inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal);
// TODO: remove // TODO: remove
inline void incrementDiffs() {
void incrementDiffs() {
++mDiffs; ++mDiffs;
} }
const int TYPED_LETTER_MULTIPLIER; const int TYPED_LETTER_MULTIPLIER;
const int FULL_WORD_MULTIPLIER; const int FULL_WORD_MULTIPLIER;
const ProximityInfo *mProximityInfo; const ProximityInfo *mProximityInfo;
int mMaxEditDistance; int mMaxEditDistance;
int mMaxDepth; int mMaxDepth;
int mInputLength; int mInputLength;
int mSkipPos; int mSkipPos;
int mSkippedOutputIndex;
int mExcessivePos; int mExcessivePos;
int mTransposedPos; int mTransposedPos;
int mSpaceProximityPos; int mSpaceProximityPos;
int mMissingSpacePos; int mMissingSpacePos;
int mMatchedCharCount;
int mInputIndex;
int mOutputIndex;
int mTerminalInputIndex; int mTerminalInputIndex;
int mTerminalOutputIndex; int mTerminalOutputIndex;
int mDiffs;
bool mTraverseAllNodes;
unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL]; CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];
inline bool isQuote(const unsigned short c); // The following member variables are being used as cache values of the correction state.
inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal); int mOutputIndex;
int mInputIndex;
int mDiffs;
int mMatchedCharCount;
int mSkippedCount;
bool mNeedsToTraverseAllNodes;
bool mMatching;
bool mSkipping;
class RankingAlgorithm { class RankingAlgorithm {
public: public:

View File

@ -23,32 +23,33 @@
namespace latinime { namespace latinime {
class CorrectionState { struct CorrectionState {
public:
int mParentIndex; int mParentIndex;
int mMatchedCount;
int mChildCount;
int mInputIndex;
int mDiffs;
int mSiblingPos; int mSiblingPos;
bool mTraverseAll; uint16_t mChildCount;
uint8_t mInputIndex;
uint8_t mDiffs;
uint8_t mMatchedCount;
uint8_t mSkippedCount;
bool mMatching;
bool mSkipping;
bool mNeedsToTraverseAllNodes;
inline void init(const int rootPos, const int childCount, const bool traverseAll) {
set(-1, 0, childCount, 0, 0, rootPos, traverseAll);
}
private:
inline void set(const int parentIndex, const int matchedCount, const int childCount,
const int inputIndex, const int diffs, const int siblingPos,
const bool traverseAll) {
mParentIndex = parentIndex;
mMatchedCount = matchedCount;
mChildCount = childCount;
mInputIndex = inputIndex;
mDiffs = diffs;
mSiblingPos = siblingPos;
mTraverseAll = traverseAll;
}
}; };
inline static void initCorrectionState(CorrectionState *state, const int rootPos,
const uint16_t childCount, const bool traverseAll) {
state->mParentIndex = -1;
state->mChildCount = childCount;
state->mInputIndex = 0;
state->mDiffs = 0;
state->mSiblingPos = rootPos;
state->mMatchedCount = 0;
state->mSkippedCount = 0;
state->mMatching = false;
state->mSkipping = false;
state->mNeedsToTraverseAllNodes = traverseAll;
}
} // namespace latinime } // namespace latinime
#endif // LATINIME_CORRECTION_STATE_H #endif // LATINIME_CORRECTION_STATE_H

View File

@ -160,6 +160,7 @@ static void prof_out(void) {
#define WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE 60 #define WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE 60
#define FULL_MATCHED_WORDS_PROMOTION_RATE 120 #define FULL_MATCHED_WORDS_PROMOTION_RATE 120
#define WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE 90 #define WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE 90
#define WORDS_WITH_MATCH_SKIP_PROMOTION_RATE 105
// This should be greater than or equal to MAX_WORD_LENGTH defined in BinaryDictionary.java // This should be greater than or equal to MAX_WORD_LENGTH defined in BinaryDictionary.java
// This is only used for the size of array. Not to be used in c functions. // This is only used for the size of array. Not to be used in c functions.

View File

@ -46,6 +46,7 @@ public:
ProximityType getMatchedProximityId( ProximityType getMatchedProximityId(
const int index, const unsigned short c, const bool checkProximityChars) const; const int index, const unsigned short c, const bool checkProximityChars) const;
bool sameAsTyped(const unsigned short *word, int length) const; bool sameAsTyped(const unsigned short *word, int length) const;
private: private:
int getStartIndexFromCoordinates(const int x, const int y) const; int getStartIndexFromCoordinates(const int x, const int y) const;
const int MAX_PROXIMITY_CHARS_SIZE; const int MAX_PROXIMITY_CHARS_SIZE;