Move code related to ranking algorithm to correction_state.cpp

Change-Id: I52b34de45969fef82e46d9c10079c2d45e0b94eb
This commit is contained in:
satok 2011-08-03 02:19:44 +09:00
parent e486290013
commit 0f6c8e8aeb
5 changed files with 93 additions and 79 deletions

View file

@ -58,10 +58,32 @@ int CorrectionState::getFreqForSplitTwoWords(const int firstFreq, const int seco
return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this);
}
int CorrectionState::getFinalFreq(const int inputIndex, const int depth, const int matchWeight,
const int freq, const bool sameLength) {
return CorrectionState::RankingAlgorithm::calculateFinalFreq(inputIndex, depth, matchWeight,
freq, sameLength, this);
int CorrectionState::getFinalFreq(const int inputIndex, const int outputIndex, const int freq) {
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
: (mInputLength == inputIndex + 1);
const int matchCount = mMatchedCharCount;
return CorrectionState::RankingAlgorithm::calculateFinalFreq(
inputIndex, outputIndex, matchCount, freq, sameLength, this);
}
void CorrectionState::initDepth() {
mMatchedCharCount = 0;
}
void CorrectionState::charMatched() {
++mMatchedCharCount;
}
void CorrectionState::goUpTree(const int matchCount) {
mMatchedCharCount = matchCount;
}
void CorrectionState::slideTree(const int matchCount) {
mMatchedCharCount = matchCount;
}
void CorrectionState::goDownTree(int *matchedCount) {
*matchedCount = mMatchedCharCount;
}
CorrectionState::~CorrectionState() {
@ -117,7 +139,8 @@ inline static void multiplyRate(const int rate, int *freq) {
// RankingAlgorithm //
//////////////////////
int CorrectionState::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int depth,
int CorrectionState::RankingAlgorithm::calculateFinalFreq(
const int inputIndex, const int outputIndex,
const int matchCount, const int freq, const bool sameLength,
const CorrectionState* correctionState) {
const int skipPos = correctionState->getSkipPos();
@ -156,10 +179,10 @@ int CorrectionState::RankingAlgorithm::calculateFinalFreq(const int inputIndex,
}
}
int lengthFreq = typedLetterMultiplier;
multiplyIntCapped(powerIntCapped(typedLetterMultiplier, depth), &lengthFreq);
if (lengthFreq == matchWeight) {
multiplyIntCapped(powerIntCapped(typedLetterMultiplier, outputIndex), &lengthFreq);
if ((outputIndex + 1) == matchCount) {
// Full exact match
if (depth > 1) {
if (outputIndex > 1) {
if (DEBUG_DICT) {
LOGI("Found full matched word.");
}
@ -168,7 +191,8 @@ int CorrectionState::RankingAlgorithm::calculateFinalFreq(const int inputIndex,
if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0) {
finalFreq = capped255MultForFullMatchAccentsOrCapitalizationDifference(finalFreq);
}
} else if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0 && depth > 0) {
} else if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0
&& outputIndex > 0) {
// A word with proximity corrections
if (DEBUG_DICT) {
LOGI("Found one proximity correction.");
@ -177,7 +201,7 @@ int CorrectionState::RankingAlgorithm::calculateFinalFreq(const int inputIndex,
multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq);
}
if (DEBUG_DICT) {
LOGI("calc: %d, %d", depth, sameLength);
LOGI("calc: %d, %d", outputIndex, sameLength);
}
if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq);
return finalFreq;

View file

@ -32,7 +32,12 @@ public:
void initCorrectionState(const ProximityInfo *pi, const int inputLength);
void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
const int spaceProximityPos, const int missingSpacePos);
void initDepth();
void checkState();
void goUpTree(const int matchCount);
void slideTree(const int matchCount);
void goDownTree(int *matchedCount);
void charMatched();
virtual ~CorrectionState();
int getSkipPos() const {
return mSkipPos;
@ -50,13 +55,13 @@ public:
return mMissingSpacePos;
}
int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq);
int getFinalFreq(const int inputIndex, const int depth, const int matchWeight, const int freq,
const bool sameLength);
int getFinalFreq(const int inputIndex, const int outputIndex, const int freq);
private:
const int TYPED_LETTER_MULTIPLIER;
const int FULL_WORD_MULTIPLIER;
const ProximityInfo *mProximityInfo;
int mInputLength;
int mSkipPos;
@ -65,6 +70,8 @@ private:
int mSpaceProximityPos;
int mMissingSpacePos;
int mMatchedCharCount;
class RankingAlgorithm {
public:
static int calculateFinalFreq(const int inputIndex, const int depth,

View file

@ -176,9 +176,6 @@ static void prof_out(void) {
#define MIN_USER_TYPED_LENGTH_FOR_MISSING_SPACE_SUGGESTION 3
#define MIN_USER_TYPED_LENGTH_FOR_EXCESSIVE_CHARACTER_SUGGESTION 3
// The size of next letters frequency array. Zero will disable the feature.
#define NEXT_LETTERS_SIZE 0
#define min(a,b) ((a)<(b)?(a):(b))
#endif // LATINIME_DEFINES_H

View file

@ -167,12 +167,6 @@ int UnigramDictionary::getSuggestions(ProximityInfo *proximityInfo, const int *x
LOGI("%s %i", s, mFrequencies[j]);
#endif
}
LOGI("Next letters: ");
for (int k = 0; k < NEXT_LETTERS_SIZE; k++) {
if (mNextLettersFrequency[k] > 0) {
LOGI("%c = %d,", k, mNextLettersFrequency[k]);
}
}
}
PROF_END(20);
PROF_CLOSE;
@ -194,7 +188,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
PROF_END(0);
PROF_START(1);
getSuggestionCandidates(-1, -1, -1, mNextLettersFrequency, NEXT_LETTERS_SIZE, MAX_DEPTH);
getSuggestionCandidates(-1, -1, -1, MAX_DEPTH);
PROF_END(1);
PROF_START(2);
@ -204,7 +198,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
if (DEBUG_DICT) {
LOGI("--- Suggest missing characters %d", i);
}
getSuggestionCandidates(i, -1, -1, NULL, 0, MAX_DEPTH);
getSuggestionCandidates(i, -1, -1, MAX_DEPTH);
}
}
PROF_END(2);
@ -217,7 +211,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
if (DEBUG_DICT) {
LOGI("--- Suggest excessive characters %d", i);
}
getSuggestionCandidates(-1, i, -1, NULL, 0, MAX_DEPTH);
getSuggestionCandidates(-1, i, -1, MAX_DEPTH);
}
}
PROF_END(3);
@ -230,7 +224,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
if (DEBUG_DICT) {
LOGI("--- Suggest transposed characters %d", i);
}
getSuggestionCandidates(-1, -1, i, NULL, 0, mInputLength - 1);
getSuggestionCandidates(-1, -1, i, mInputLength - 1);
}
}
PROF_END(4);
@ -348,8 +342,7 @@ static const char QUOTE = '\'';
static const char SPACE = ' ';
void UnigramDictionary::getSuggestionCandidates(const int skipPos,
const int excessivePos, const int transposedPos, int *nextLetters,
const int nextLettersSize, const int maxDepth) {
const int excessivePos, const int transposedPos, const int maxDepth) {
if (DEBUG_DICT) {
LOGI("getSuggestionCandidates %d", maxDepth);
assert(transposedPos + 1 < mInputLength);
@ -365,29 +358,31 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos,
mStackChildCount[0] = childCount;
mStackTraverseAll[0] = (mInputLength <= 0);
mStackMatchCount[0] = 0;
mStackInputIndex[0] = 0;
mStackDiffs[0] = 0;
mStackSiblingPos[0] = rootPosition;
mStackOutputIndex[0] = 0;
mStackMatchedCount[0] = 0;
mCorrectionState->initDepth();
// Depth first search
while (depth >= 0) {
if (mStackChildCount[depth] > 0) {
--mStackChildCount[depth];
bool traverseAllNodes = mStackTraverseAll[depth];
int matchCount = mStackMatchCount[depth];
int inputIndex = mStackInputIndex[depth];
int diffs = mStackDiffs[depth];
int siblingPos = mStackSiblingPos[depth];
int outputIndex = mStackOutputIndex[depth];
int firstChildPos;
mCorrectionState->slideTree(mStackMatchedCount[depth]);
// depth will never be greater than maxDepth because in that case,
// needsToTraverseChildrenNodes should be false
const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex,
maxDepth, traverseAllNodes, matchCount, inputIndex, diffs,
nextLetters, nextLettersSize, mCorrectionState, &childCount,
&firstChildPos, &traverseAllNodes, &matchCount, &inputIndex, &diffs,
maxDepth, traverseAllNodes, inputIndex, diffs,
mCorrectionState, &childCount,
&firstChildPos, &traverseAllNodes, &inputIndex, &diffs,
&siblingPos, &outputIndex);
// Update next sibling pos
mStackSiblingPos[depth] = siblingPos;
@ -396,15 +391,21 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos,
++depth;
mStackChildCount[depth] = childCount;
mStackTraverseAll[depth] = traverseAllNodes;
mStackMatchCount[depth] = matchCount;
mStackInputIndex[depth] = inputIndex;
mStackDiffs[depth] = diffs;
mStackSiblingPos[depth] = firstChildPos;
mStackOutputIndex[depth] = outputIndex;
int matchedCount;
mCorrectionState->goDownTree(&matchedCount);
mStackMatchedCount[depth] = matchedCount;
} else {
mCorrectionState->slideTree(mStackMatchedCount[depth]);
}
} else {
// Goes to parent sibling node
--depth;
mCorrectionState->goUpTree(mStackMatchedCount[depth]);
}
}
}
@ -445,24 +446,13 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c,
}
inline void UnigramDictionary::onTerminal(unsigned short int* word, const int depth,
const uint8_t* const root, const uint8_t flags, const int pos,
const int inputIndex, const int matchCount, const int freq, const bool sameLength,
int* nextLetters, const int nextLettersSize, CorrectionState *correctionState) {
const int skipPos = correctionState->getSkipPos();
const bool isSameAsTyped = sameLength ? mProximityInfo->sameAsTyped(word, depth + 1) : false;
if (isSameAsTyped) return;
if (depth >= MIN_SUGGEST_DEPTH) {
const int finalFreq = correctionState->getFinalFreq(inputIndex, depth, matchCount,
freq, sameLength);
if (!isSameAsTyped)
addWord(word, depth + 1, finalFreq);
}
if (sameLength && depth >= mInputLength && skipPos < 0) {
registerNextLetter(word[mInputLength], nextLetters, nextLettersSize);
inline void UnigramDictionary::onTerminal(unsigned short int* word, const int outputIndex,
const int inputIndex, const int freq, CorrectionState *correctionState) {
if (!mProximityInfo->sameAsTyped(word, outputIndex + 1) && outputIndex >= MIN_SUGGEST_DEPTH) {
const int finalFreq = correctionState->getFinalFreq(inputIndex, outputIndex, freq);
if (finalFreq >= 0) {
addWord(word, outputIndex + 1, finalFreq);
}
}
}
@ -677,11 +667,11 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs
// there aren't any more nodes at this level, it merely returns the address of the first byte after
// the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any
// given level, as output into newCount when traversing this level's parent.
inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth,
const int maxDepth, const bool initialTraverseAllNodes, int matchCount, int inputIndex,
const int initialDiffs, int *nextLetters, const int nextLettersSize,
inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialOutputPos,
const int maxDepth, const bool initialTraverseAllNodes, int inputIndex,
const int initialDiffs,
CorrectionState *correctionState, int *newCount, int *newChildrenPosition,
bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs,
bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs,
int *nextSiblingPosition, int *newOutputIndex) {
const int skipPos = correctionState->getSkipPos();
const int excessivePos = correctionState->getExcessivePos();
@ -690,7 +680,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
correctionState->checkState();
}
int pos = initialPos;
int depth = initialDepth;
int internalOutputPos = initialOutputPos;
int traverseAllNodes = initialTraverseAllNodes;
int diffs = initialDiffs;
@ -736,15 +726,16 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
// This has to be done for each virtual char (this forwards the "inputIndex" which
// is the index in the user-inputted chars, as read by proximity chars.
if (excessivePos == depth && inputIndex < mInputLength - 1) ++inputIndex;
if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, depth)) {
mWord[depth] = c;
if (excessivePos == internalOutputPos && inputIndex < mInputLength - 1) {
++inputIndex;
}
if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, internalOutputPos)) {
mWord[internalOutputPos] = c;
if (traverseAllNodes && isTerminal) {
// The frequency should be here, because we come here only if this is actually
// a terminal node, and we are on its last char.
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchCount,
freq, false, nextLetters, nextLettersSize, mCorrectionState);
onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState);
}
if (!hasChildren) {
// If we don't have children here, that means we finished processing all
@ -784,18 +775,17 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
return false;
}
mWord[depth] = c;
mWord[internalOutputPos] = c;
// If inputIndex is greater than mInputLength, that means there is no
// proximity chars. So, we don't need to check proximity.
if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
++matchCount;
correctionState->charMatched();
}
const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1
|| (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2);
if (isSameAsUserTypedLength && isTerminal) {
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchCount,
freq, true, nextLetters, nextLettersSize, mCorrectionState);
onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState);
}
// This character matched the typed character (enough to traverse the node at least)
// so we just evaluated it. Now we should evaluate this virtual node's children - that
@ -821,7 +811,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
++inputIndex;
}
// Optimization: Prune out words that are too long compared to how much was typed.
if (depth >= maxDepth || diffs > mMaxEditDistance) {
if (internalOutputPos >= maxDepth || diffs > mMaxEditDistance) {
// We are giving up parsing this node and its children. Skip the rest of the node,
// output the sibling position, and return that we don't want to traverse children.
if (!isLastChar) {
@ -838,7 +828,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
// contain NOT_A_CHARACTER.
c = nextc;
// Also, the next char is one "virtual node" depth more than this char.
++depth;
++internalOutputPos;
} while (NOT_A_CHARACTER != c);
// If inputIndex is greater than mInputLength, that means there are no proximity chars.
@ -850,10 +840,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
// All the output values that are purely computation by this function are held in local
// variables. Output them to the caller.
*newTraverseAllNodes = traverseAllNodes;
*newMatchRate = matchCount;
*newDiffs = diffs;
*newInputIndex = inputIndex;
*newOutputIndex = depth;
*newOutputIndex = internalOutputPos;
// Now we finished processing this node, and we want to traverse children. If there are no
// children, we can't come here.

View file

@ -87,8 +87,7 @@ private:
const int *ycoordinates, const int *codes, const int codesSize,
unsigned short *outWords, int *frequencies);
void getSuggestionCandidates(const int skipPos, const int excessivePos,
const int transposedPos, int *nextLetters, const int nextLettersSize,
const int maxDepth);
const int transposedPos, const int maxDepth);
bool addWord(unsigned short *word, int length, int frequency);
void getSplitTwoWordsSuggestion(const int inputLength, CorrectionState *correctionState);
void getMissingSpaceWords(
@ -96,17 +95,16 @@ private:
void getMistypedSpaceWords(
const int inputLength, const int spaceProximityPos, CorrectionState *correctionState);
void onTerminal(unsigned short int* word, const int depth,
const uint8_t* const root, const uint8_t flags, const int pos,
const int inputIndex, const int matchWeight, const int freq, const bool sameLength,
int* nextLetters, const int nextLettersSize, CorrectionState *correctionState);
const int inputIndex, const int freq,
CorrectionState *correctionState);
bool needsToSkipCurrentNode(const unsigned short c,
const int inputIndex, const int skipPos, const int depth);
// Process a node by considering proximity, missing and excessive character
bool processCurrentNode(const int initialPos, const int initialDepth,
const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex,
const int initialDiffs, int *nextLetters, const int nextLettersSize,
const int maxDepth, const bool initialTraverseAllNodes, int inputIndex,
const int initialDiffs,
CorrectionState *correctionState, int *newCount, int *newChildPosition,
bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs,
bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs,
int *nextSiblingPosition, int *nextOutputIndex);
int getMostFrequentWordLike(const int startInputIndex, const int inputLength,
unsigned short *word);
@ -142,14 +140,13 @@ private:
unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
int mMaxEditDistance;
int mStackMatchedCount[MAX_WORD_LENGTH_INTERNAL];
int mStackChildCount[MAX_WORD_LENGTH_INTERNAL];
bool mStackTraverseAll[MAX_WORD_LENGTH_INTERNAL];
int mStackMatchCount[MAX_WORD_LENGTH_INTERNAL];
int mStackInputIndex[MAX_WORD_LENGTH_INTERNAL];
int mStackDiffs[MAX_WORD_LENGTH_INTERNAL];
int mStackSiblingPos[MAX_WORD_LENGTH_INTERNAL];
int mStackOutputIndex[MAX_WORD_LENGTH_INTERNAL];
int mNextLettersFrequency[NEXT_LETTERS_SIZE];
};
} // namespace latinime