Move the input index and output index to correction state

Change-Id: Idebdb59143f3367929df6a0475cefe941eb16d01
main
satok 2011-08-03 23:27:32 +09:00
parent bb12dc455b
commit 4e4e74e6b6
4 changed files with 102 additions and 83 deletions

View File

@ -58,32 +58,49 @@ int CorrectionState::getFreqForSplitTwoWords(const int firstFreq, const int seco
return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this); return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this);
} }
// Computes the final frequency for a candidate word from the current correction state.
// NOTE: this is a side-by-side diff rendering; each line shows the pre-change text
// followed by the post-change text.
//
// Post-change version (word, freq):
//   - Returns -1 when mProximityInfo->sameAsTyped(word, mOutputIndex + 1) holds or when
//     mOutputIndex < MIN_SUGGEST_DEPTH. The caller shown in this commit (onTerminal) only
//     adds the word when the returned value is >= 0.
//   - Otherwise computes sameLength from mInputLength, mInputIndex and mExcessivePos and
//     delegates to RankingAlgorithm::calculateFinalFreq using the member indices
//     mInputIndex / mOutputIndex and mMatchedCharCount.
// Pre-change version (inputIndex, outputIndex, freq): took both indices as parameters and
// performed no sameAsTyped / MIN_SUGGEST_DEPTH check itself.
int CorrectionState::getFinalFreq(const int inputIndex, const int outputIndex, const int freq) { int CorrectionState::getFinalFreq(const unsigned short *word, const int freq) {
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2) if (mProximityInfo->sameAsTyped(word, mOutputIndex + 1) || mOutputIndex < MIN_SUGGEST_DEPTH) {
: (mInputLength == inputIndex + 1); return -1;
const int matchCount = mMatchedCharCount; }
// sameLength: when the excessive position is the last input position the comparison is
// against index + 2, otherwise against index + 1.
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == mInputIndex + 2)
: (mInputLength == mInputIndex + 1);
return CorrectionState::RankingAlgorithm::calculateFinalFreq( return CorrectionState::RankingAlgorithm::calculateFinalFreq(
inputIndex, outputIndex, matchCount, freq, sameLength, this); mInputIndex, mOutputIndex, mMatchedCharCount, freq, sameLength, this);
} }
// Side-by-side diff: the left column is the removed initDepth(), which reset
// mMatchedCharCount to 0; the right column adds initProcessState() and getProcessState().
//
// initProcessState(matchCount, inputIndex, outputIndex): overwrites mMatchedCharCount,
// mInputIndex and mOutputIndex with the given values. In this commit it is called from
// getSuggestionCandidates with the values stored on the per-depth stacks.
void CorrectionState::initDepth() { void CorrectionState::initProcessState(
mMatchedCharCount = 0; const int matchCount, const int inputIndex, const int outputIndex) {
mMatchedCharCount = matchCount;
mInputIndex = inputIndex;
mOutputIndex = outputIndex;
}
// getProcessState(matchedCount, inputIndex, outputIndex): copies the three state members
// out through the given pointers. The pointers are dereferenced without a null check, so
// callers must pass valid addresses.
void CorrectionState::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex) {
*matchedCount = mMatchedCharCount;
*inputIndex = mInputIndex;
*outputIndex = mOutputIndex;
} }
// Increments mMatchedCharCount by one. Unchanged by this commit (both diff columns are
// identical). processCurrentNode calls it when the matched proximity id is
// ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR.
void CorrectionState::charMatched() { void CorrectionState::charMatched() {
++mMatchedCharCount; ++mMatchedCharCount;
} }
// Side-by-side diff: the left column is the removed goUpTree(matchCount), which assigned
// matchCount to mMatchedCharCount; the right column adds getOutputIndex(), which returns
// mOutputIndex. The new accessor carries a "TODO: remove" marker from the author.
void CorrectionState::goUpTree(const int matchCount) { // TODO: remove
mMatchedCharCount = matchCount; int CorrectionState::getOutputIndex() {
return mOutputIndex;
} }
// Side-by-side diff: the left column is the removed slideTree(matchCount), which assigned
// matchCount to mMatchedCharCount (same body as the removed goUpTree); the right column
// adds getInputIndex(), which returns mInputIndex. Also marked "TODO: remove" by the author.
void CorrectionState::slideTree(const int matchCount) { // TODO: remove
mMatchedCharCount = matchCount; int CorrectionState::getInputIndex() {
return mInputIndex;
} }
// Side-by-side diff: the left column is the removed goDownTree(matchedCount), which wrote
// mMatchedCharCount through the given pointer; the right column adds two mutators.
//
// incrementInputIndex(): advances mInputIndex by one.
void CorrectionState::goDownTree(int *matchedCount) { void CorrectionState::incrementInputIndex() {
*matchedCount = mMatchedCharCount; ++mInputIndex;
}
// incrementOutputIndex(): advances mOutputIndex by one. processCurrentNode calls it once
// per character before moving on to the next character of the node.
void CorrectionState::incrementOutputIndex() {
++mOutputIndex;
} }
CorrectionState::~CorrectionState() { CorrectionState::~CorrectionState() {

View File

@ -28,16 +28,25 @@ class ProximityInfo;
class CorrectionState { class CorrectionState {
public: public:
// Enumeration added by this commit. No use of CorrectionStateType appears in the visible
// hunks, so the meaning of each enumerator cannot be confirmed from here.
// NOTE(review): presumably classifies how a dictionary character relates to the typed
// input -- confirm against the code that consumes it.
typedef enum {
ALLOW_ALL,
UNRELATED,
RELATED
} CorrectionStateType;
CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier); CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier);
void initCorrectionState(const ProximityInfo *pi, const int inputLength); void initCorrectionState(const ProximityInfo *pi, const int inputLength);
void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
const int spaceProximityPos, const int missingSpacePos); const int spaceProximityPos, const int missingSpacePos);
void initDepth();
void checkState(); void checkState();
void goUpTree(const int matchCount); void initProcessState(const int matchCount, const int inputIndex, const int outputIndex);
void slideTree(const int matchCount); void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex);
void goDownTree(int *matchedCount);
void charMatched(); void charMatched();
void incrementInputIndex();
void incrementOutputIndex();
int getOutputIndex();
int getInputIndex();
virtual ~CorrectionState(); virtual ~CorrectionState();
int getSkipPos() const { int getSkipPos() const {
return mSkipPos; return mSkipPos;
@ -55,7 +64,7 @@ public:
return mMissingSpacePos; return mMissingSpacePos;
} }
int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq); int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq);
int getFinalFreq(const int inputIndex, const int outputIndex, const int freq); int getFinalFreq(const unsigned short *word, const int freq);
private: private:
@ -71,6 +80,8 @@ private:
int mMissingSpacePos; int mMissingSpacePos;
int mMatchedCharCount; int mMatchedCharCount;
int mInputIndex;
int mOutputIndex;
class RankingAlgorithm { class RankingAlgorithm {
public: public:

View File

@ -363,27 +363,25 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos,
mStackSiblingPos[0] = rootPosition; mStackSiblingPos[0] = rootPosition;
mStackOutputIndex[0] = 0; mStackOutputIndex[0] = 0;
mStackMatchedCount[0] = 0; mStackMatchedCount[0] = 0;
mCorrectionState->initDepth();
// Depth first search // Depth first search
while (depth >= 0) { while (depth >= 0) {
if (mStackChildCount[depth] > 0) { if (mStackChildCount[depth] > 0) {
--mStackChildCount[depth]; --mStackChildCount[depth];
bool traverseAllNodes = mStackTraverseAll[depth]; bool traverseAllNodes = mStackTraverseAll[depth];
int inputIndex = mStackInputIndex[depth];
int diffs = mStackDiffs[depth]; int diffs = mStackDiffs[depth];
int siblingPos = mStackSiblingPos[depth]; int siblingPos = mStackSiblingPos[depth];
int outputIndex = mStackOutputIndex[depth];
int firstChildPos; int firstChildPos;
mCorrectionState->slideTree(mStackMatchedCount[depth]); mCorrectionState->initProcessState(
mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth]);
// depth will never be greater than maxDepth because in that case, // depth will never be greater than maxDepth because in that case,
// needsToTraverseChildrenNodes should be false // needsToTraverseChildrenNodes should be false
const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex, const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos,
maxDepth, traverseAllNodes, inputIndex, diffs, maxDepth, traverseAllNodes, diffs,
mCorrectionState, &childCount, mCorrectionState, &childCount,
&firstChildPos, &traverseAllNodes, &inputIndex, &diffs, &firstChildPos, &traverseAllNodes, &diffs,
&siblingPos, &outputIndex); &siblingPos);
// Update next sibling pos // Update next sibling pos
mStackSiblingPos[depth] = siblingPos; mStackSiblingPos[depth] = siblingPos;
if (needsToTraverseChildrenNodes) { if (needsToTraverseChildrenNodes) {
@ -391,21 +389,15 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos,
++depth; ++depth;
mStackChildCount[depth] = childCount; mStackChildCount[depth] = childCount;
mStackTraverseAll[depth] = traverseAllNodes; mStackTraverseAll[depth] = traverseAllNodes;
mStackInputIndex[depth] = inputIndex;
mStackDiffs[depth] = diffs; mStackDiffs[depth] = diffs;
mStackSiblingPos[depth] = firstChildPos; mStackSiblingPos[depth] = firstChildPos;
mStackOutputIndex[depth] = outputIndex;
int matchedCount; mCorrectionState->getProcessState(&mStackMatchedCount[depth],
mCorrectionState->goDownTree(&matchedCount); &mStackInputIndex[depth], &mStackOutputIndex[depth]);
mStackMatchedCount[depth] = matchedCount;
} else {
mCorrectionState->slideTree(mStackMatchedCount[depth]);
} }
} else { } else {
// Goes to parent sibling node // Goes to parent sibling node
--depth; --depth;
mCorrectionState->goUpTree(mStackMatchedCount[depth]);
} }
} }
} }
@ -446,13 +438,11 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c,
} }
// Called by processCurrentNode when a terminal node is reached: scores the word in `word`
// and adds it to the suggestions when the score is non-negative.
// NOTE: this is a side-by-side diff rendering; each line shows the pre-change text
// followed by the post-change text.
//
// Post-change version (word, freq, correctionState): asks correctionState->getFinalFreq()
// for the score and, when it is >= 0, calls addWord with a length of
// correctionState->getOutputIndex() + 1.
// Pre-change version (word, outputIndex, inputIndex, freq, correctionState): performed the
// sameAsTyped / MIN_SUGGEST_DEPTH check here before asking for the final frequency; the
// post-change column of getFinalFreq shows that check moved into CorrectionState.
inline void UnigramDictionary::onTerminal(unsigned short int* word, const int outputIndex, inline void UnigramDictionary::onTerminal(
const int inputIndex, const int freq, CorrectionState *correctionState) { unsigned short int* word, const int freq, CorrectionState *correctionState) {
if (!mProximityInfo->sameAsTyped(word, outputIndex + 1) && outputIndex >= MIN_SUGGEST_DEPTH) { const int finalFreq = correctionState->getFinalFreq(word, freq);
const int finalFreq = correctionState->getFinalFreq(inputIndex, outputIndex, freq);
if (finalFreq >= 0) { if (finalFreq >= 0) {
addWord(word, outputIndex + 1, finalFreq); addWord(word, correctionState->getOutputIndex() + 1, finalFreq);
}
} }
} }
@ -667,12 +657,10 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs
// there aren't any more nodes at this level, it merely returns the address of the first byte after // there aren't any more nodes at this level, it merely returns the address of the first byte after
// the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any
// given level, as output into newCount when traversing this level's parent. // given level, as output into newCount when traversing this level's parent.
inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialOutputPos, inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int maxDepth,
const int maxDepth, const bool initialTraverseAllNodes, int inputIndex, const bool initialTraverseAllNodes, const int initialDiffs,
const int initialDiffs,
CorrectionState *correctionState, int *newCount, int *newChildrenPosition, CorrectionState *correctionState, int *newCount, int *newChildrenPosition,
bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs, bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition) {
int *nextSiblingPosition, int *newOutputIndex) {
const int skipPos = correctionState->getSkipPos(); const int skipPos = correctionState->getSkipPos();
const int excessivePos = correctionState->getExcessivePos(); const int excessivePos = correctionState->getExcessivePos();
const int transposedPos = correctionState->getTransposedPos(); const int transposedPos = correctionState->getTransposedPos();
@ -680,9 +668,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
correctionState->checkState(); correctionState->checkState();
} }
int pos = initialPos; int pos = initialPos;
int internalOutputPos = initialOutputPos;
int traverseAllNodes = initialTraverseAllNodes; int traverseAllNodes = initialTraverseAllNodes;
int diffs = initialDiffs; int diffs = initialDiffs;
const int initialInputIndex = correctionState->getInputIndex();
// Flags contain the following information: // Flags contain the following information:
// - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits: // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits:
@ -726,16 +714,18 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
// This has to be done for each virtual char (this forwards the "inputIndex" which // This has to be done for each virtual char (this forwards the "inputIndex" which
// is the index in the user-inputted chars, as read by proximity chars). // is the index in the user-inputted chars, as read by proximity chars).
if (excessivePos == internalOutputPos && inputIndex < mInputLength - 1) { if (excessivePos == correctionState->getOutputIndex()
++inputIndex; && correctionState->getInputIndex() < mInputLength - 1) {
correctionState->incrementInputIndex();
} }
if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, internalOutputPos)) { if (traverseAllNodes || needsToSkipCurrentNode(
mWord[internalOutputPos] = c; c, correctionState->getInputIndex(), skipPos, correctionState->getOutputIndex())) {
mWord[correctionState->getOutputIndex()] = c;
if (traverseAllNodes && isTerminal) { if (traverseAllNodes && isTerminal) {
// The frequency should be here, because we come here only if this is actually // The frequency should be here, because we come here only if this is actually
// a terminal node, and we are on its last char. // a terminal node, and we are on its last char.
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState); onTerminal(mWord, freq, mCorrectionState);
} }
if (!hasChildren) { if (!hasChildren) {
// If we don't have children here, that means we finished processing all // If we don't have children here, that means we finished processing all
@ -750,11 +740,15 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
return false; return false;
} }
} else { } else {
int inputIndexForProximity = inputIndex; int inputIndexForProximity = correctionState->getInputIndex();
if (transposedPos >= 0) { if (transposedPos >= 0) {
if (inputIndex == transposedPos) ++inputIndexForProximity; if (correctionState->getInputIndex() == transposedPos) {
if (inputIndex == (transposedPos + 1)) --inputIndexForProximity; ++inputIndexForProximity;
}
if (correctionState->getInputIndex() == (transposedPos + 1)) {
--inputIndexForProximity;
}
} }
int matchedProximityCharId = mProximityInfo->getMatchedProximityId( int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
@ -775,18 +769,31 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
return false; return false;
} }
mWord[internalOutputPos] = c; mWord[correctionState->getOutputIndex()] = c;
// If inputIndex is greater than mInputLength, that means there are no // If inputIndex is greater than mInputLength, that means there are no
// proximity chars. So, we don't need to check proximity. // proximity chars. So, we don't need to check proximity.
if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
correctionState->charMatched(); correctionState->charMatched();
} }
const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1 const bool isSameAsUserTypedLength = mInputLength
|| (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2); == correctionState->getInputIndex() + 1
|| (excessivePos == mInputLength - 1
&& correctionState->getInputIndex() == mInputLength - 2);
if (isSameAsUserTypedLength && isTerminal) { if (isSameAsUserTypedLength && isTerminal) {
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState); onTerminal(mWord, freq, mCorrectionState);
} }
// Start traversing all nodes after the index exceeds the user typed length
traverseAllNodes = isSameAsUserTypedLength;
diffs = diffs
+ ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0);
// Finally, we are ready to go to the next character, the next "virtual node".
// We should advance the input index.
// We do this in this branch of the 'if traverseAllNodes' because we are still matching
// characters to input; the other branch is not matching them but searching for
// completions, this is why it does not have to do it.
correctionState->incrementInputIndex();
// This character matched the typed character (enough to traverse the node at least) // This character matched the typed character (enough to traverse the node at least)
// so we just evaluated it. Now we should evaluate this virtual node's children - that // so we just evaluated it. Now we should evaluate this virtual node's children - that
// is, if it has any. If it has no children, we're done here - so we skip the end of // is, if it has any. If it has no children, we're done here - so we skip the end of
@ -799,19 +806,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
return false; return false;
} }
// Start traversing all nodes after the index exceeds the user typed length
traverseAllNodes = isSameAsUserTypedLength;
diffs = diffs
+ ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0);
// Finally, we are ready to go to the next character, the next "virtual node".
// We should advance the input index.
// We do this in this branch of the 'if traverseAllNodes' because we are still matching
// characters to input; the other branch is not matching them but searching for
// completions, this is why it does not have to do it.
++inputIndex;
} }
// Optimization: Prune out words that are too long compared to how much was typed. // Optimization: Prune out words that are too long compared to how much was typed.
if (internalOutputPos >= maxDepth || diffs > mMaxEditDistance) { if (correctionState->getOutputIndex() >= maxDepth || diffs > mMaxEditDistance) {
// We are giving up parsing this node and its children. Skip the rest of the node, // We are giving up parsing this node and its children. Skip the rest of the node,
// output the sibling position, and return that we don't want to traverse children. // output the sibling position, and return that we don't want to traverse children.
if (!isLastChar) { if (!isLastChar) {
@ -822,18 +819,18 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
return false; return false;
} }
// Also, the next char is one "virtual node" depth more than this char.
correctionState->incrementOutputIndex();
// Prepare for the next character. Promote the prefetched char to current char - the loop // Prepare for the next character. Promote the prefetched char to current char - the loop
// will take care of prefetching the next. If we finally found our last char, nextc will // will take care of prefetching the next. If we finally found our last char, nextc will
// contain NOT_A_CHARACTER. // contain NOT_A_CHARACTER.
c = nextc; c = nextc;
// Also, the next char is one "virtual node" depth more than this char.
++internalOutputPos;
} while (NOT_A_CHARACTER != c); } while (NOT_A_CHARACTER != c);
// If inputIndex is greater than mInputLength, that means there are no proximity chars. // If inputIndex is greater than mInputLength, that means there are no proximity chars.
// Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength. // Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength.
if (mInputLength <= *newInputIndex) { if (mInputLength <= initialInputIndex) {
traverseAllNodes = true; traverseAllNodes = true;
} }
@ -841,8 +838,6 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
// variables. Output them to the caller. // variables. Output them to the caller.
*newTraverseAllNodes = traverseAllNodes; *newTraverseAllNodes = traverseAllNodes;
*newDiffs = diffs; *newDiffs = diffs;
*newInputIndex = inputIndex;
*newOutputIndex = internalOutputPos;
// Now we finished processing this node, and we want to traverse children. If there are no // Now we finished processing this node, and we want to traverse children. If there are no
// children, we can't come here. // children, we can't come here.

View File

@ -94,18 +94,14 @@ private:
const int inputLength, const int missingSpacePos, CorrectionState *correctionState); const int inputLength, const int missingSpacePos, CorrectionState *correctionState);
void getMistypedSpaceWords( void getMistypedSpaceWords(
const int inputLength, const int spaceProximityPos, CorrectionState *correctionState); const int inputLength, const int spaceProximityPos, CorrectionState *correctionState);
void onTerminal(unsigned short int* word, const int depth, void onTerminal(unsigned short int* word, const int freq, CorrectionState *correctionState);
const int inputIndex, const int freq,
CorrectionState *correctionState);
bool needsToSkipCurrentNode(const unsigned short c, bool needsToSkipCurrentNode(const unsigned short c,
const int inputIndex, const int skipPos, const int depth); const int inputIndex, const int skipPos, const int depth);
// Process a node by considering proximity, missing and excessive character // Process a node by considering proximity, missing and excessive character
bool processCurrentNode(const int initialPos, const int initialDepth, bool processCurrentNode(const int initialPos, const int maxDepth,
const int maxDepth, const bool initialTraverseAllNodes, int inputIndex, const bool initialTraverseAllNodes, const int initialDiffs,
const int initialDiffs,
CorrectionState *correctionState, int *newCount, int *newChildPosition, CorrectionState *correctionState, int *newCount, int *newChildPosition,
bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs, bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition);
int *nextSiblingPosition, int *nextOutputIndex);
int getMostFrequentWordLike(const int startInputIndex, const int inputLength, int getMostFrequentWordLike(const int startInputIndex, const int inputLength,
unsigned short *word); unsigned short *word);
int getMostFrequentWordLikeInner(const uint16_t* const inWord, const int length, int getMostFrequentWordLikeInner(const uint16_t* const inWord, const int length,