Move scoring part to the correction state

Change-Id: I2dc4a0869636fce5526f48b3a6267b6bdf61dbfb
main
satok 2011-08-04 18:31:57 +09:00
parent 2e2906bc17
commit 8876b75ca1
4 changed files with 248 additions and 185 deletions

View File

@ -25,13 +25,31 @@
namespace latinime { namespace latinime {
//////////////////////
// inline functions //
//////////////////////
static const char QUOTE = '\'';
inline bool CorrectionState::needsToSkipCurrentNode(const unsigned short c) {
const unsigned short userTypedChar = mProximityInfo->getPrimaryCharAt(mInputIndex);
// Skip the ' or other letter and continue deeper
return (c == QUOTE && userTypedChar != QUOTE) || mSkipPos == mOutputIndex;
}
/////////////////////
// CorrectionState //
/////////////////////
CorrectionState::CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier) CorrectionState::CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier)
: TYPED_LETTER_MULTIPLIER(typedLetterMultiplier), FULL_WORD_MULTIPLIER(fullWordMultiplier) { : TYPED_LETTER_MULTIPLIER(typedLetterMultiplier), FULL_WORD_MULTIPLIER(fullWordMultiplier) {
} }
void CorrectionState::initCorrectionState(const ProximityInfo *pi, const int inputLength) { void CorrectionState::initCorrectionState(const ProximityInfo *pi, const int inputLength,
const int maxDepth) {
mProximityInfo = pi; mProximityInfo = pi;
mInputLength = inputLength; mInputLength = inputLength;
mMaxDepth = maxDepth;
mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
} }
void CorrectionState::setCorrectionParams(const int skipPos, const int excessivePos, void CorrectionState::setCorrectionParams(const int skipPos, const int excessivePos,
@ -58,27 +76,37 @@ int CorrectionState::getFreqForSplitTwoWords(const int firstFreq, const int seco
return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this); return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this);
} }
int CorrectionState::getFinalFreq(const unsigned short *word, const int freq) { int CorrectionState::getFinalFreq(const int freq, unsigned short **word, int *wordLength) {
if (mProximityInfo->sameAsTyped(word, mOutputIndex + 1) || mOutputIndex < MIN_SUGGEST_DEPTH) { const int outputIndex = mOutputIndex - 1;
const int inputIndex = (mCurrentStateType == TRAVERSE_ALL_ON_TERMINAL
|| mCurrentStateType == TRAVERSE_ALL_NOT_ON_TERMINAL) ? mInputIndex : mInputIndex - 1;
*wordLength = outputIndex + 1;
if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) {
return -1; return -1;
} }
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == mInputIndex + 2) *word = mWord;
: (mInputLength == mInputIndex + 1); const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
: (mInputLength == inputIndex + 1);
return CorrectionState::RankingAlgorithm::calculateFinalFreq( return CorrectionState::RankingAlgorithm::calculateFinalFreq(
mInputIndex, mOutputIndex, mMatchedCharCount, freq, sameLength, this); inputIndex, outputIndex, mMatchedCharCount, freq, sameLength, this);
} }
void CorrectionState::initProcessState( void CorrectionState::initProcessState(const int matchCount, const int inputIndex,
const int matchCount, const int inputIndex, const int outputIndex) { const int outputIndex, const bool traverseAllNodes, const int diffs) {
mMatchedCharCount = matchCount; mMatchedCharCount = matchCount;
mInputIndex = inputIndex; mInputIndex = inputIndex;
mOutputIndex = outputIndex; mOutputIndex = outputIndex;
mTraverseAllNodes = traverseAllNodes;
mDiffs = diffs;
} }
void CorrectionState::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex) { void CorrectionState::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex,
bool *traverseAllNodes, int *diffs) {
*matchedCount = mMatchedCharCount; *matchedCount = mMatchedCharCount;
*inputIndex = mInputIndex; *inputIndex = mInputIndex;
*outputIndex = mOutputIndex; *outputIndex = mOutputIndex;
*traverseAllNodes = mTraverseAllNodes;
*diffs = mDiffs;
} }
void CorrectionState::charMatched() { void CorrectionState::charMatched() {
@ -95,6 +123,11 @@ int CorrectionState::getInputIndex() {
return mInputIndex; return mInputIndex;
} }
// TODO: remove
bool CorrectionState::needsToTraverseAll() {
return mTraverseAllNodes;
}
void CorrectionState::incrementInputIndex() { void CorrectionState::incrementInputIndex() {
++mInputIndex; ++mInputIndex;
} }
@ -103,6 +136,86 @@ void CorrectionState::incrementOutputIndex() {
++mOutputIndex; ++mOutputIndex;
} }
void CorrectionState::startTraverseAll() {
mTraverseAllNodes = true;
}
bool CorrectionState::needsToPrune() const {
return (mOutputIndex - 1 >= (mTransposedPos >= 0 ? mInputLength - 1 : mMaxDepth)
|| mDiffs > mMaxEditDistance);
}
CorrectionState::CorrectionStateType CorrectionState::processCharAndCalcState(
const int32_t c, const bool isTerminal) {
mCurrentStateType = NOT_ON_TERMINAL;
// This has to be done for each virtual char (this forwards the "inputIndex" which
// is the index in the user-inputted chars, as read by proximity chars.
if (mExcessivePos == mOutputIndex && mInputIndex < mInputLength - 1) {
incrementInputIndex();
}
if (mTraverseAllNodes || needsToSkipCurrentNode(c)) {
mWord[mOutputIndex] = c;
if (needsToTraverseAll() && isTerminal) {
mCurrentStateType = TRAVERSE_ALL_ON_TERMINAL;
} else {
mCurrentStateType = TRAVERSE_ALL_NOT_ON_TERMINAL;
}
} else {
int inputIndexForProximity = mInputIndex;
if (mTransposedPos >= 0) {
if (mInputIndex == mTransposedPos) {
++inputIndexForProximity;
}
if (mInputIndex == (mTransposedPos + 1)) {
--inputIndexForProximity;
}
}
int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
inputIndexForProximity, c, this);
if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
mCurrentStateType = UNRELATED;
return mCurrentStateType;
}
mWord[mOutputIndex] = c;
// If inputIndex is greater than mInputLength, that means there is no
// proximity chars. So, we don't need to check proximity.
if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
charMatched();
}
if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) {
incrementDiffs();
}
const bool isSameAsUserTypedLength = mInputLength
== getInputIndex() + 1
|| (mExcessivePos == mInputLength - 1
&& getInputIndex() == mInputLength - 2);
if (isSameAsUserTypedLength && isTerminal) {
mCurrentStateType = ON_TERMINAL;
}
// Start traversing all nodes after the index exceeds the user typed length
if (isSameAsUserTypedLength) {
startTraverseAll();
}
// Finally, we are ready to go to the next character, the next "virtual node".
// We should advance the input index.
// We do this in this branch of the 'if traverseAllNodes' because we are still matching
// characters to input; the other branch is not matching them but searching for
// completions, this is why it does not have to do it.
incrementInputIndex();
}
// Also, the next char is one "virtual node" depth more than this char.
incrementOutputIndex();
return mCurrentStateType;
}
CorrectionState::~CorrectionState() { CorrectionState::~CorrectionState() {
} }

View File

@ -29,49 +29,76 @@ class CorrectionState {
public: public:
typedef enum { typedef enum {
ALLOW_ALL, TRAVERSE_ALL_ON_TERMINAL,
TRAVERSE_ALL_NOT_ON_TERMINAL,
UNRELATED, UNRELATED,
RELATED ON_TERMINAL,
NOT_ON_TERMINAL
} CorrectionStateType; } CorrectionStateType;
CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier); CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier);
void initCorrectionState(const ProximityInfo *pi, const int inputLength); void initCorrectionState(
const ProximityInfo *pi, const int inputLength, const int maxWordLength);
void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
const int spaceProximityPos, const int missingSpacePos); const int spaceProximityPos, const int missingSpacePos);
void checkState(); void checkState();
void initProcessState(const int matchCount, const int inputIndex, const int outputIndex); void initProcessState(const int matchCount, const int inputIndex, const int outputIndex,
void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex); const bool traverseAllNodes, const int diffs);
void charMatched(); void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex,
void incrementInputIndex(); bool *traverseAllNodes, int *diffs);
void incrementOutputIndex();
int getOutputIndex(); int getOutputIndex();
int getInputIndex(); int getInputIndex();
bool needsToTraverseAll();
virtual ~CorrectionState(); virtual ~CorrectionState();
int getSkipPos() const {
return mSkipPos;
}
int getExcessivePos() const {
return mExcessivePos;
}
int getTransposedPos() const {
return mTransposedPos;
}
int getSpaceProximityPos() const { int getSpaceProximityPos() const {
return mSpaceProximityPos; return mSpaceProximityPos;
} }
int getMissingSpacePos() const { int getMissingSpacePos() const {
return mMissingSpacePos; return mMissingSpacePos;
} }
int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq);
int getFinalFreq(const unsigned short *word, const int freq);
int getSkipPos() const {
return mSkipPos;
}
int getExcessivePos() const {
return mExcessivePos;
}
int getTransposedPos() const {
return mTransposedPos;
}
bool needsToPrune() const;
int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq);
int getFinalFreq(const int freq, unsigned short **word, int* wordLength);
CorrectionStateType processCharAndCalcState(const int32_t c, const bool isTerminal);
int getDiffs() const {
return mDiffs;
}
private: private:
void charMatched();
void incrementInputIndex();
void incrementOutputIndex();
void startTraverseAll();
// TODO: remove
void incrementDiffs() {
++mDiffs;
}
const int TYPED_LETTER_MULTIPLIER; const int TYPED_LETTER_MULTIPLIER;
const int FULL_WORD_MULTIPLIER; const int FULL_WORD_MULTIPLIER;
const ProximityInfo *mProximityInfo; const ProximityInfo *mProximityInfo;
int mMaxEditDistance;
int mMaxDepth;
int mInputLength; int mInputLength;
int mSkipPos; int mSkipPos;
int mExcessivePos; int mExcessivePos;
@ -82,6 +109,12 @@ private:
int mMatchedCharCount; int mMatchedCharCount;
int mInputIndex; int mInputIndex;
int mOutputIndex; int mOutputIndex;
int mDiffs;
bool mTraverseAllNodes;
CorrectionStateType mCurrentStateType;
unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
inline bool needsToSkipCurrentNode(const unsigned short c);
class RankingAlgorithm { class RankingAlgorithm {
public: public:

View File

@ -181,14 +181,14 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
PROF_START(0); PROF_START(0);
initSuggestions( initSuggestions(
proximityInfo, xcoordinates, ycoordinates, codes, codesSize, outWords, frequencies); proximityInfo, xcoordinates, ycoordinates, codes, codesSize, outWords, frequencies);
mCorrectionState->initCorrectionState(mProximityInfo, mInputLength);
if (DEBUG_DICT) assert(codesSize == mInputLength); if (DEBUG_DICT) assert(codesSize == mInputLength);
const int MAX_DEPTH = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH); const int maxDepth = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH);
mCorrectionState->initCorrectionState(mProximityInfo, mInputLength, maxDepth);
PROF_END(0); PROF_END(0);
PROF_START(1); PROF_START(1);
getSuggestionCandidates(-1, -1, -1, MAX_DEPTH); getSuggestionCandidates(-1, -1, -1);
PROF_END(1); PROF_END(1);
PROF_START(2); PROF_START(2);
@ -198,7 +198,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
if (DEBUG_DICT) { if (DEBUG_DICT) {
LOGI("--- Suggest missing characters %d", i); LOGI("--- Suggest missing characters %d", i);
} }
getSuggestionCandidates(i, -1, -1, MAX_DEPTH); getSuggestionCandidates(i, -1, -1);
} }
} }
PROF_END(2); PROF_END(2);
@ -211,7 +211,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
if (DEBUG_DICT) { if (DEBUG_DICT) {
LOGI("--- Suggest excessive characters %d", i); LOGI("--- Suggest excessive characters %d", i);
} }
getSuggestionCandidates(-1, i, -1, MAX_DEPTH); getSuggestionCandidates(-1, i, -1);
} }
} }
PROF_END(3); PROF_END(3);
@ -224,7 +224,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
if (DEBUG_DICT) { if (DEBUG_DICT) {
LOGI("--- Suggest transposed characters %d", i); LOGI("--- Suggest transposed characters %d", i);
} }
getSuggestionCandidates(-1, -1, i, mInputLength - 1); getSuggestionCandidates(-1, -1, i);
} }
} }
PROF_END(4); PROF_END(4);
@ -272,7 +272,6 @@ void UnigramDictionary::initSuggestions(ProximityInfo *proximityInfo, const int
mFrequencies = frequencies; mFrequencies = frequencies;
mOutputChars = outWords; mOutputChars = outWords;
mInputLength = codesSize; mInputLength = codesSize;
mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
proximityInfo->setInputParams(codes, codesSize); proximityInfo->setInputParams(codes, codesSize);
mProximityInfo = proximityInfo; mProximityInfo = proximityInfo;
} }
@ -342,9 +341,8 @@ static const char QUOTE = '\'';
static const char SPACE = ' '; static const char SPACE = ' ';
void UnigramDictionary::getSuggestionCandidates(const int skipPos, void UnigramDictionary::getSuggestionCandidates(const int skipPos,
const int excessivePos, const int transposedPos, const int maxDepth) { const int excessivePos, const int transposedPos) {
if (DEBUG_DICT) { if (DEBUG_DICT) {
LOGI("getSuggestionCandidates %d", maxDepth);
assert(transposedPos + 1 < mInputLength); assert(transposedPos + 1 < mInputLength);
assert(excessivePos < mInputLength); assert(excessivePos < mInputLength);
assert(missingPos < mInputLength); assert(missingPos < mInputLength);
@ -368,32 +366,26 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos,
while (depth >= 0) { while (depth >= 0) {
if (mStackChildCount[depth] > 0) { if (mStackChildCount[depth] > 0) {
--mStackChildCount[depth]; --mStackChildCount[depth];
bool traverseAllNodes = mStackTraverseAll[depth];
int diffs = mStackDiffs[depth];
int siblingPos = mStackSiblingPos[depth]; int siblingPos = mStackSiblingPos[depth];
int firstChildPos; int firstChildPos;
mCorrectionState->initProcessState( mCorrectionState->initProcessState(
mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth]); mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth],
mStackTraverseAll[depth], mStackDiffs[depth]);
// depth will never be greater than maxDepth because in that case,
// needsToTraverseChildrenNodes should be false // needsToTraverseChildrenNodes should be false
const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos,
maxDepth, traverseAllNodes, diffs, mCorrectionState, &childCount, &firstChildPos, &siblingPos);
mCorrectionState, &childCount,
&firstChildPos, &traverseAllNodes, &diffs,
&siblingPos);
// Update next sibling pos // Update next sibling pos
mStackSiblingPos[depth] = siblingPos; mStackSiblingPos[depth] = siblingPos;
if (needsToTraverseChildrenNodes) { if (needsToTraverseChildrenNodes) {
// Goes to child node // Goes to child node
++depth; ++depth;
mStackChildCount[depth] = childCount; mStackChildCount[depth] = childCount;
mStackTraverseAll[depth] = traverseAllNodes;
mStackDiffs[depth] = diffs;
mStackSiblingPos[depth] = firstChildPos; mStackSiblingPos[depth] = firstChildPos;
mCorrectionState->getProcessState(&mStackMatchedCount[depth], mCorrectionState->getProcessState(&mStackMatchedCount[depth],
&mStackInputIndex[depth], &mStackOutputIndex[depth]); &mStackInputIndex[depth], &mStackOutputIndex[depth],
&mStackTraverseAll[depth], &mStackDiffs[depth]);
} }
} else { } else {
// Goes to parent sibling node // Goes to parent sibling node
@ -437,12 +429,12 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c,
return (c == QUOTE && userTypedChar != QUOTE) || skipPos == depth; return (c == QUOTE && userTypedChar != QUOTE) || skipPos == depth;
} }
inline void UnigramDictionary::onTerminal(const int freq, CorrectionState *correctionState) {
inline void UnigramDictionary::onTerminal( int wordLength;
unsigned short int* word, const int freq, CorrectionState *correctionState) { unsigned short* wordPointer;
const int finalFreq = correctionState->getFinalFreq(word, freq); const int finalFreq = correctionState->getFinalFreq(freq, &wordPointer, &wordLength);
if (finalFreq >= 0) { if (finalFreq >= 0) {
addWord(word, correctionState->getOutputIndex() + 1, finalFreq); addWord(wordPointer, wordLength, finalFreq);
} }
} }
@ -657,20 +649,13 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs
// there aren't any more nodes at this level, it merely returns the address of the first byte after // there aren't any more nodes at this level, it merely returns the address of the first byte after
// the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any
// given level, as output into newCount when traversing this level's parent. // given level, as output into newCount when traversing this level's parent.
inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int maxDepth, inline bool UnigramDictionary::processCurrentNode(const int initialPos,
const bool initialTraverseAllNodes, const int initialDiffs, CorrectionState *correctionState, int *newCount,
CorrectionState *correctionState, int *newCount, int *newChildrenPosition, int *newChildrenPosition, int *nextSiblingPosition) {
bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition) {
const int skipPos = correctionState->getSkipPos();
const int excessivePos = correctionState->getExcessivePos();
const int transposedPos = correctionState->getTransposedPos();
if (DEBUG_DICT) { if (DEBUG_DICT) {
correctionState->checkState(); correctionState->checkState();
} }
int pos = initialPos; int pos = initialPos;
int traverseAllNodes = initialTraverseAllNodes;
int diffs = initialDiffs;
const int initialInputIndex = correctionState->getInputIndex();
// Flags contain the following information: // Flags contain the following information:
// - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits: // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits:
@ -682,6 +667,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
// - FLAG_HAS_BIGRAMS: whether this node has bigrams or not // - FLAG_HAS_BIGRAMS: whether this node has bigrams or not
const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(DICT_ROOT, &pos); const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(DICT_ROOT, &pos);
const bool hasMultipleChars = (0 != (FLAG_HAS_MULTIPLE_CHARS & flags)); const bool hasMultipleChars = (0 != (FLAG_HAS_MULTIPLE_CHARS & flags));
const bool isTerminalNode = (0 != (FLAG_IS_TERMINAL & flags));
bool needsToInvokeOnTerminal = false;
// This gets only ONE character from the stream. Next there will be: // This gets only ONE character from the stream. Next there will be:
// if FLAG_HAS_MULTIPLE CHARS: the other characters of the same node // if FLAG_HAS_MULTIPLE CHARS: the other characters of the same node
@ -707,53 +695,14 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
const bool isLastChar = (NOT_A_CHARACTER == nextc); const bool isLastChar = (NOT_A_CHARACTER == nextc);
// If there are more chars in this nodes, then this virtual node is not a terminal. // If there are more chars in this nodes, then this virtual node is not a terminal.
// If we are on the last char, this virtual node is a terminal if this node is. // If we are on the last char, this virtual node is a terminal if this node is.
const bool isTerminal = isLastChar && (0 != (FLAG_IS_TERMINAL & flags)); const bool isTerminal = isLastChar && isTerminalNode;
// If there are more chars in this node, then this virtual node has children.
// If we are on the last char, this virtual node has children if this node has.
const bool hasChildren = (!isLastChar) || BinaryFormat::hasChildrenInFlags(flags);
// This has to be done for each virtual char (this forwards the "inputIndex" which CorrectionState::CorrectionStateType stateType = correctionState->processCharAndCalcState(
// is the index in the user-inputted chars, as read by proximity chars. c, isTerminal);
if (excessivePos == correctionState->getOutputIndex() if (stateType == CorrectionState::TRAVERSE_ALL_ON_TERMINAL
&& correctionState->getInputIndex() < mInputLength - 1) { || stateType == CorrectionState::ON_TERMINAL) {
correctionState->incrementInputIndex(); needsToInvokeOnTerminal = true;
} } else if (stateType == CorrectionState::UNRELATED) {
if (traverseAllNodes || needsToSkipCurrentNode(
c, correctionState->getInputIndex(), skipPos, correctionState->getOutputIndex())) {
mWord[correctionState->getOutputIndex()] = c;
if (traverseAllNodes && isTerminal) {
// The frequency should be here, because we come here only if this is actually
// a terminal node, and we are on its last char.
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
onTerminal(mWord, freq, mCorrectionState);
}
if (!hasChildren) {
// If we don't have children here, that means we finished processing all
// characters of this node (we are on the last virtual node), AND we are in
// traverseAllNodes mode, which means we are searching for *completions*. We
// should skip the frequency if we have a terminal, and report the position
// of the next sibling. We don't have to return other values because we are
// returning false, as in "don't traverse children".
if (isTerminal) pos = BinaryFormat::skipFrequency(flags, pos);
*nextSiblingPosition =
BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
return false;
}
} else {
int inputIndexForProximity = correctionState->getInputIndex();
if (transposedPos >= 0) {
if (correctionState->getInputIndex() == transposedPos) {
++inputIndexForProximity;
}
if (correctionState->getInputIndex() == (transposedPos + 1)) {
--inputIndexForProximity;
}
}
int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
inputIndexForProximity, c, mCorrectionState);
if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
// We found that this is an unrelated character, so we should give up traversing // We found that this is an unrelated character, so we should give up traversing
// this node and its children entirely. // this node and its children entirely.
// However we may not be on the last virtual node yet so we skip the remaining // However we may not be on the last virtual node yet so we skip the remaining
@ -769,30 +718,24 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
return false; return false;
} }
mWord[correctionState->getOutputIndex()] = c;
// If inputIndex is greater than mInputLength, that means there is no // Prepare for the next character. Promote the prefetched char to current char - the loop
// proximity chars. So, we don't need to check proximity. // will take care of prefetching the next. If we finally found our last char, nextc will
if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { // contain NOT_A_CHARACTER.
correctionState->charMatched(); c = nextc;
} } while (NOT_A_CHARACTER != c);
const bool isSameAsUserTypedLength = mInputLength
== correctionState->getInputIndex() + 1 if (isTerminalNode) {
|| (excessivePos == mInputLength - 1 if (needsToInvokeOnTerminal) {
&& correctionState->getInputIndex() == mInputLength - 2); // The frequency should be here, because we come here only if this is actually
if (isSameAsUserTypedLength && isTerminal) { // a terminal node, and we are on its last char.
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
onTerminal(mWord, freq, mCorrectionState); onTerminal(freq, mCorrectionState);
} }
// Start traversing all nodes after the index exceeds the user typed length
traverseAllNodes = isSameAsUserTypedLength; // If there are more chars in this node, then this virtual node has children.
diffs = diffs // If we are on the last char, this virtual node has children if this node has.
+ ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0); const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags);
// Finally, we are ready to go to the next character, the next "virtual node".
// We should advance the input index.
// We do this in this branch of the 'if traverseAllNodes' because we are still matching
// characters to input; the other branch is not matching them but searching for
// completions, this is why it does not have to do it.
correctionState->incrementInputIndex();
// This character matched the typed character (enough to traverse the node at least) // This character matched the typed character (enough to traverse the node at least)
// so we just evaluated it. Now we should evaluate this virtual node's children - that // so we just evaluated it. Now we should evaluate this virtual node's children - that
@ -806,40 +749,16 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
return false; return false;
} }
}
// Optimization: Prune out words that are too long compared to how much was typed. // Optimization: Prune out words that are too long compared to how much was typed.
if (isTerminal if (correctionState->needsToPrune()) {
&& (correctionState->getOutputIndex() >= maxDepth || diffs > mMaxEditDistance)) {
// We are giving up parsing this node and its children. Skip the rest of the node,
// output the sibling position, and return that we don't want to traverse children.
if (!isLastChar) {
pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos);
}
pos = BinaryFormat::skipFrequency(flags, pos); pos = BinaryFormat::skipFrequency(flags, pos);
*nextSiblingPosition = *nextSiblingPosition =
BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
return false; return false;
} }
// Also, the next char is one "virtual node" depth more than this char.
correctionState->incrementOutputIndex();
// Prepare for the next character. Promote the prefetched char to current char - the loop
// will take care of prefetching the next. If we finally found our last char, nextc will
// contain NOT_A_CHARACTER.
c = nextc;
} while (NOT_A_CHARACTER != c);
// If inputIndex is greater than mInputLength, that means there are no proximity chars.
// Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength.
if (mInputLength <= initialInputIndex) {
traverseAllNodes = true;
} }
// All the output values that are purely computation by this function are held in local
// variables. Output them to the caller.
*newTraverseAllNodes = traverseAllNodes;
*newDiffs = diffs;
// Now we finished processing this node, and we want to traverse children. If there are no // Now we finished processing this node, and we want to traverse children. If there are no
// children, we can't come here. // children, we can't come here.
assert(BinaryFormat::hasChildrenInFlags(flags)); assert(BinaryFormat::hasChildrenInFlags(flags));

View File

@ -87,21 +87,20 @@ private:
const int *ycoordinates, const int *codes, const int codesSize, const int *ycoordinates, const int *codes, const int codesSize,
unsigned short *outWords, int *frequencies); unsigned short *outWords, int *frequencies);
void getSuggestionCandidates(const int skipPos, const int excessivePos, void getSuggestionCandidates(const int skipPos, const int excessivePos,
const int transposedPos, const int maxDepth); const int transposedPos);
bool addWord(unsigned short *word, int length, int frequency); bool addWord(unsigned short *word, int length, int frequency);
void getSplitTwoWordsSuggestion(const int inputLength, CorrectionState *correctionState); void getSplitTwoWordsSuggestion(const int inputLength, CorrectionState *correctionState);
void getMissingSpaceWords( void getMissingSpaceWords(
const int inputLength, const int missingSpacePos, CorrectionState *correctionState); const int inputLength, const int missingSpacePos, CorrectionState *correctionState);
void getMistypedSpaceWords( void getMistypedSpaceWords(
const int inputLength, const int spaceProximityPos, CorrectionState *correctionState); const int inputLength, const int spaceProximityPos, CorrectionState *correctionState);
void onTerminal(unsigned short int* word, const int freq, CorrectionState *correctionState); void onTerminal(const int freq, CorrectionState *correctionState);
bool needsToSkipCurrentNode(const unsigned short c, bool needsToSkipCurrentNode(const unsigned short c,
const int inputIndex, const int skipPos, const int depth); const int inputIndex, const int skipPos, const int depth);
// Process a node by considering proximity, missing and excessive character // Process a node by considering proximity, missing and excessive character
bool processCurrentNode(const int initialPos, const int maxDepth, bool processCurrentNode(const int initialPos,
const bool initialTraverseAllNodes, const int initialDiffs, CorrectionState *correctionState, int *newCount,
CorrectionState *correctionState, int *newCount, int *newChildPosition, int *newChildPosition, int *nextSiblingPosition);
bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition);
int getMostFrequentWordLike(const int startInputIndex, const int inputLength, int getMostFrequentWordLike(const int startInputIndex, const int inputLength,
unsigned short *word); unsigned short *word);
int getMostFrequentWordLikeInner(const uint16_t* const inWord, const int length, int getMostFrequentWordLikeInner(const uint16_t* const inWord, const int length,
@ -134,7 +133,6 @@ private:
int mInputLength; int mInputLength;
// MAX_WORD_LENGTH_INTERNAL must be bigger than MAX_WORD_LENGTH // MAX_WORD_LENGTH_INTERNAL must be bigger than MAX_WORD_LENGTH
unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
int mMaxEditDistance;
int mStackMatchedCount[MAX_WORD_LENGTH_INTERNAL]; int mStackMatchedCount[MAX_WORD_LENGTH_INTERNAL];
int mStackChildCount[MAX_WORD_LENGTH_INTERNAL]; int mStackChildCount[MAX_WORD_LENGTH_INTERNAL];