Improve skip correction algorithm

Change-Id: Ife45e3886137d60a4e903d4c6f7a9ef20c7e705a
This commit is contained in:
satok 2011-08-10 22:19:33 +09:00
parent c359d75ca7
commit 635f68e822
5 changed files with 94 additions and 57 deletions

View file

@ -49,12 +49,11 @@ void Correction::initCorrection(const ProximityInfo *pi, const int inputLength,
mInputLength = inputLength;
mMaxDepth = maxDepth;
mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
mSkippedOutputIndex = -1;
}
void Correction::initCorrectionState(
const int rootPos, const int childCount, const bool traverseAll) {
mCorrectionStates[0].init(rootPos, childCount, traverseAll);
latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll);
}
void Correction::setCorrectionParams(const int skipPos, const int excessivePos,
@ -88,6 +87,12 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen
if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) {
return -1;
}
// TODO: Remove this
if (mSkipPos >= 0 && mSkippedCount <= 0) {
return -1;
}
*word = mWord;
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
: (mInputLength == inputIndex + 1);
@ -103,8 +108,11 @@ bool Correction::initProcessState(const int outputIndex) {
--(mCorrectionStates[outputIndex].mChildCount);
mMatchedCharCount = mCorrectionStates[outputIndex].mMatchedCount;
mInputIndex = mCorrectionStates[outputIndex].mInputIndex;
mTraverseAllNodes = mCorrectionStates[outputIndex].mTraverseAll;
mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes;
mDiffs = mCorrectionStates[outputIndex].mDiffs;
mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount;
mSkipping = false;
mMatching = false;
return true;
}
@ -131,8 +139,8 @@ int Correction::getInputIndex() {
}
// TODO: remove
bool Correction::needsToTraverseAll() {
return mTraverseAllNodes;
bool Correction::needsToTraverseAllNodes() {
return mNeedsToTraverseAllNodes;
}
void Correction::incrementInputIndex() {
@ -146,12 +154,15 @@ void Correction::incrementOutputIndex() {
mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos;
mCorrectionStates[mOutputIndex].mMatchedCount = mMatchedCharCount;
mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex;
mCorrectionStates[mOutputIndex].mTraverseAll = mTraverseAllNodes;
mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;
mCorrectionStates[mOutputIndex].mDiffs = mDiffs;
mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount;
mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
mCorrectionStates[mOutputIndex].mMatching = mMatching;
}
void Correction::startTraverseAll() {
mTraverseAllNodes = true;
void Correction::startToTraverseAllNodes() {
mNeedsToTraverseAllNodes = true;
}
bool Correction::needsToPrune() const {
@ -162,7 +173,7 @@ bool Correction::needsToPrune() const {
Correction::CorrectionType Correction::processSkipChar(
const int32_t c, const bool isTerminal) {
mWord[mOutputIndex] = c;
if (needsToTraverseAll() && isTerminal) {
if (needsToTraverseAllNodes() && isTerminal) {
mTerminalInputIndex = mInputIndex;
mTerminalOutputIndex = mOutputIndex;
incrementOutputIndex();
@ -185,9 +196,10 @@ Correction::CorrectionType Correction::processCharAndCalcState(
bool skip = false;
if (mSkipPos >= 0) {
skip = mSkipPos == mOutputIndex;
mSkipping = true;
}
if (mTraverseAllNodes || isQuote(c)) {
if (mNeedsToTraverseAllNodes || isQuote(c)) {
return processSkipChar(c, isTerminal);
} else {
int inputIndexForProximity = mInputIndex;
@ -210,25 +222,23 @@ Correction::CorrectionType Correction::processCharAndCalcState(
if (unrelated) {
if (skip) {
// Skip this letter and continue deeper
mSkippedOutputIndex = mOutputIndex;
++mSkippedCount;
return processSkipChar(c, isTerminal);
} else {
return UNRELATED;
}
}
// No need to skip. Finish traversing and increment skipPos.
// TODO: Remove this?
// TODO: remove after allowing combination errors
if (skip) {
mWord[mOutputIndex] = c;
incrementOutputIndex();
return TRAVERSE_ALL_NOT_ON_TERMINAL;
return UNRELATED;
}
mWord[mOutputIndex] = c;
// If inputIndex is greater than mInputLength, that means there is no
// proximity chars. So, we don't need to check proximity.
if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
mMatching = true;
charMatched();
}
@ -247,7 +257,7 @@ Correction::CorrectionType Correction::processCharAndCalcState(
}
// Start traversing all nodes after the index exceeds the user typed length
if (isSameAsUserTypedLength) {
startTraverseAll();
startToTraverseAllNodes();
}
// Finally, we are ready to go to the next character, the next "virtual node".
@ -317,6 +327,7 @@ inline static void multiplyRate(const int rate, int *freq) {
// RankingAlgorithm //
//////////////////////
/* static */
int Correction::RankingAlgorithm::calculateFinalFreq(
const int inputIndex, const int outputIndex,
const int matchCount, const int freq, const bool sameLength,
@ -329,6 +340,8 @@ int Correction::RankingAlgorithm::calculateFinalFreq(
const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER;
const ProximityInfo *proximityInfo = correction->mProximityInfo;
const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
const unsigned short* word = correction->mWord;
const int skippedCount = correction->mSkippedCount;
// TODO: Demote by edit distance
int finalFreq = freq * matchWeight;
@ -382,9 +395,30 @@ int Correction::RankingAlgorithm::calculateFinalFreq(
LOGI("calc: %d, %d", outputIndex, sameLength);
}
if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq);
// TODO: check excessive count and transposed count
/*
If the last character of the user input word is the same as the next character
of the output word, and also all of characters of the user input are matched
to the output word, we'll promote that word a bit because
that word can be considered the combination of skipped and matched characters.
This means that the 'sm' pattern wins over the 'ma' pattern.
e.g.)
shel -> shell [mmmma] or [mmmsm]
hel -> hello [mmmaa] or [mmsma]
m ... matching
s ... skipping
a ... traversing all
*/
if (matchCount == inputLength && matchCount >= 2 && skippedCount == 0
&& word[matchCount] == word[matchCount - 1]) {
multiplyRate(WORDS_WITH_MATCH_SKIP_PROMOTION_RATE, &finalFreq);
}
return finalFreq;
}
/* static */
int Correction::RankingAlgorithm::calcFreqForSplitTwoWords(
const int firstFreq, const int secondFreq, const Correction* correction) {
const int spaceProximityPos = correction->mSpaceProximityPos;

View file

@ -52,7 +52,6 @@ public:
bool *traverseAllNodes, int *diffs);
int getOutputIndex();
int getInputIndex();
bool needsToTraverseAll();
virtual ~Correction();
int getSpaceProximityPos() const {
@ -101,45 +100,46 @@ public:
return mCorrectionStates[index].mParentIndex;
}
private:
void charMatched();
void incrementInputIndex();
void incrementOutputIndex();
void startTraverseAll();
inline void charMatched();
inline void incrementInputIndex();
inline void incrementOutputIndex();
inline bool needsToTraverseAllNodes();
inline void startToTraverseAllNodes();
inline bool isQuote(const unsigned short c);
inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal);
// TODO: remove
void incrementDiffs() {
inline void incrementDiffs() {
++mDiffs;
}
const int TYPED_LETTER_MULTIPLIER;
const int FULL_WORD_MULTIPLIER;
const ProximityInfo *mProximityInfo;
int mMaxEditDistance;
int mMaxDepth;
int mInputLength;
int mSkipPos;
int mSkippedOutputIndex;
int mExcessivePos;
int mTransposedPos;
int mSpaceProximityPos;
int mMissingSpacePos;
int mMatchedCharCount;
int mInputIndex;
int mOutputIndex;
int mTerminalInputIndex;
int mTerminalOutputIndex;
int mDiffs;
bool mTraverseAllNodes;
unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];
inline bool isQuote(const unsigned short c);
inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal);
// The following member variables are being used as cache values of the correction state.
int mOutputIndex;
int mInputIndex;
int mDiffs;
int mMatchedCharCount;
int mSkippedCount;
bool mNeedsToTraverseAllNodes;
bool mMatching;
bool mSkipping;
class RankingAlgorithm {
public:

View file

@ -23,32 +23,33 @@
namespace latinime {
class CorrectionState {
public:
struct CorrectionState {
int mParentIndex;
int mMatchedCount;
int mChildCount;
int mInputIndex;
int mDiffs;
int mSiblingPos;
bool mTraverseAll;
uint16_t mChildCount;
uint8_t mInputIndex;
uint8_t mDiffs;
uint8_t mMatchedCount;
uint8_t mSkippedCount;
bool mMatching;
bool mSkipping;
bool mNeedsToTraverseAllNodes;
inline void init(const int rootPos, const int childCount, const bool traverseAll) {
set(-1, 0, childCount, 0, 0, rootPos, traverseAll);
}
private:
inline void set(const int parentIndex, const int matchedCount, const int childCount,
const int inputIndex, const int diffs, const int siblingPos,
const bool traverseAll) {
mParentIndex = parentIndex;
mMatchedCount = matchedCount;
mChildCount = childCount;
mInputIndex = inputIndex;
mDiffs = diffs;
mSiblingPos = siblingPos;
mTraverseAll = traverseAll;
}
};
inline static void initCorrectionState(CorrectionState *state, const int rootPos,
const uint16_t childCount, const bool traverseAll) {
state->mParentIndex = -1;
state->mChildCount = childCount;
state->mInputIndex = 0;
state->mDiffs = 0;
state->mSiblingPos = rootPos;
state->mMatchedCount = 0;
state->mSkippedCount = 0;
state->mMatching = false;
state->mSkipping = false;
state->mNeedsToTraverseAllNodes = traverseAll;
}
} // namespace latinime
#endif // LATINIME_CORRECTION_STATE_H

View file

@ -160,6 +160,7 @@ static void prof_out(void) {
#define WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE 60
#define FULL_MATCHED_WORDS_PROMOTION_RATE 120
#define WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE 90
#define WORDS_WITH_MATCH_SKIP_PROMOTION_RATE 105
// This should be greater than or equal to MAX_WORD_LENGTH defined in BinaryDictionary.java
// This is only used for the size of array. Not to be used in c functions.

View file

@ -46,6 +46,7 @@ public:
ProximityType getMatchedProximityId(
const int index, const unsigned short c, const bool checkProximityChars) const;
bool sameAsTyped(const unsigned short *word, int length) const;
private:
int getStartIndexFromCoordinates(const int x, const int y) const;
const int MAX_PROXIMITY_CHARS_SIZE;