From 2df3060883c7535029c7dfbbb4f7b05935d796ae Mon Sep 17 00:00:00 2001 From: satok Date: Fri, 15 Jul 2011 13:49:00 +0900 Subject: [PATCH] Add correction state Change-Id: I0d281cede1590893bd1def005cf83c9431d12750 --- native/Android.mk | 1 + native/src/correction_state.cpp | 52 ++++++++++++++ native/src/correction_state.h | 52 ++++++++++++++ native/src/proximity_info.cpp | 8 ++- native/src/proximity_info.h | 5 +- native/src/unigram_dictionary.cpp | 111 +++++++++++------------------- native/src/unigram_dictionary.h | 34 ++++----- 7 files changed, 166 insertions(+), 97 deletions(-) create mode 100644 native/src/correction_state.cpp create mode 100644 native/src/correction_state.h diff --git a/native/Android.mk b/native/Android.mk index bc246a990..63cd68cf6 100644 --- a/native/Android.mk +++ b/native/Android.mk @@ -17,6 +17,7 @@ LOCAL_SRC_FILES := \ jni/jni_common.cpp \ src/bigram_dictionary.cpp \ src/char_utils.cpp \ + src/correction_state.cpp \ src/dictionary.cpp \ src/proximity_info.cpp \ src/unigram_dictionary.cpp diff --git a/native/src/correction_state.cpp b/native/src/correction_state.cpp new file mode 100644 index 000000000..aa5efce40 --- /dev/null +++ b/native/src/correction_state.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2011 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#define LOG_TAG "LatinIME: correction_state.cpp" + +#include "correction_state.h" + +namespace latinime { + +CorrectionState::CorrectionState() { +} + +void CorrectionState::setCorrectionParams(const ProximityInfo *pi, const int inputLength, + const int skipPos, const int excessivePos, const int transposedPos) { + mProximityInfo = pi; + mSkipPos = skipPos; + mExcessivePos = excessivePos; + mTransposedPos = transposedPos; +} + +void CorrectionState::checkState() { + if (DEBUG_DICT) { + int inputCount = 0; + if (mSkipPos >= 0) ++inputCount; + if (mExcessivePos >= 0) ++inputCount; + if (mTransposedPos >= 0) ++inputCount; + // TODO: remove this assert + assert(inputCount <= 1); + } +} + +CorrectionState::~CorrectionState() { +} + +} // namespace latinime diff --git a/native/src/correction_state.h b/native/src/correction_state.h new file mode 100644 index 000000000..5b7392590 --- /dev/null +++ b/native/src/correction_state.h @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2011 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_CORRECTION_STATE_H +#define LATINIME_CORRECTION_STATE_H + +#include + +#include "defines.h" + +namespace latinime { + +class ProximityInfo; + +class CorrectionState { +public: + CorrectionState(); + void setCorrectionParams(const ProximityInfo *pi, const int inputLength, const int skipPos, + const int excessivePos, const int transposedPos); + void checkState(); + virtual ~CorrectionState(); + int getSkipPos() const { + return mSkipPos; + } + int getExcessivePos() const { + return mExcessivePos; + } + int getTransposedPos() const { + return mTransposedPos; + } +private: + const ProximityInfo *mProximityInfo; + int mInputLength; + int mSkipPos; + int mExcessivePos; + int mTransposedPos; +}; +} // namespace latinime +#endif // LATINIME_CORRECTION_INFO_H diff --git a/native/src/proximity_info.cpp b/native/src/proximity_info.cpp index c45393f18..bed92cf9e 100644 --- a/native/src/proximity_info.cpp +++ b/native/src/proximity_info.cpp @@ -78,7 +78,7 @@ unsigned short ProximityInfo::getPrimaryCharAt(const int index) const { return getProximityCharsAt(index)[0]; } -bool ProximityInfo::existsCharInProximityAt(const int index, const int c) const { +inline bool ProximityInfo::existsCharInProximityAt(const int index, const int c) const { const int *chars = getProximityCharsAt(index); int i = 0; while (chars[i] > 0 && i < MAX_PROXIMITY_CHARS_SIZE) { @@ -114,8 +114,10 @@ bool ProximityInfo::existsAdjacentProximityChars(const int index) const { // in their list. The non-accented version of the character should be considered // "close", but not the other keys close to the non-accented version. ProximityInfo::ProximityType ProximityInfo::getMatchedProximityId( - const int index, const unsigned short c, const int skipPos, - const int excessivePos, const int transposedPos) const { + const int index, const unsigned short c, CorrectionState *correctionState) const { + const int skipPos = correctionState->getSkipPos(); + const int excessivePos = correctionState->getExcessivePos(); + const int transposedPos = correctionState->getTransposedPos(); const int *currentChars = getProximityCharsAt(index); const unsigned short baseLowerC = Dictionary::toBaseLowerCase(c); diff --git a/native/src/proximity_info.h b/native/src/proximity_info.h index 435a60151..b28191d01 100644 --- a/native/src/proximity_info.h +++ b/native/src/proximity_info.h @@ -23,6 +23,8 @@ namespace latinime { +class CorrectionState; + class ProximityInfo { public: typedef enum { // Used as a return value for character comparison @@ -42,8 +44,7 @@ public: bool existsCharInProximityAt(const int index, const int c) const; bool existsAdjacentProximityChars(const int index) const; ProximityType getMatchedProximityId( - const int index, const unsigned short c, const int skipPos, - const int excessivePos, const int transposedPos) const; + const int index, const unsigned short c, CorrectionState *correctionState) const; bool sameAsTyped(const unsigned short *word, int length) const; private: int getStartIndexFromCoordinates(const int x, const int y) const; diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index bccd37a61..27d15157c 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -58,9 +58,12 @@ UnigramDictionary::UnigramDictionary(const uint8_t* const streamStart, int typed if (DEBUG_DICT) { LOGI("UnigramDictionary - constructor"); } + mCorrectionState = new CorrectionState(); } -UnigramDictionary::~UnigramDictionary() {} +UnigramDictionary::~UnigramDictionary() { + delete mCorrectionState; +} static inline unsigned int getCodesBufferSize(const int* codes, const int codesSize, const int MAX_PROXIMITY_CHARS) { @@ -362,6 +365,8 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, assert(excessivePos < mInputLength); assert(missingPos < mInputLength); } + mCorrectionState->setCorrectionParams(mProximityInfo, mInputLength, skipPos, excessivePos, + transposedPos); int rootPosition = ROOT_POS; // Get the number of children of root, then increment the position int childCount = Dictionary::getCount(DICT_ROOT, &rootPosition); @@ -389,8 +394,8 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, // depth will never be greater than maxDepth because in that case, // needsToTraverseChildrenNodes should be false const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex, - maxDepth, traverseAllNodes, matchWeight, inputIndex, diffs, skipPos, - excessivePos, transposedPos, nextLetters, nextLettersSize, &childCount, + maxDepth, traverseAllNodes, matchWeight, inputIndex, diffs, + nextLetters, nextLettersSize, mCorrectionState, &childCount, &firstChildPos, &traverseAllNodes, &matchWeight, &inputIndex, &diffs, &siblingPos, &outputIndex); // Update next sibling pos @@ -521,8 +526,12 @@ bool UnigramDictionary::getMistypedSpaceWords(const int inputLength, const int s } inline int UnigramDictionary::calculateFinalFreq(const int inputIndex, const int depth, - const int matchWeight, const int skipPos, const int excessivePos, const int transposedPos, - const int freq, const bool sameLength) const { + const int matchWeight, const int freq, const bool sameLength, + CorrectionState *correctionState) const { + const int skipPos = correctionState->getSkipPos(); + const int excessivePos = correctionState->getExcessivePos(); + const int transposedPos = correctionState->getTransposedPos(); + // TODO: Demote by edit distance int finalFreq = freq * matchWeight; if (skipPos >= 0) { @@ -587,16 +596,16 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c, inline void UnigramDictionary::onTerminal(unsigned short int* word, const int depth, const uint8_t* const root, const uint8_t flags, const int pos, - const int inputIndex, const int matchWeight, const int skipPos, - const int excessivePos, const int transposedPos, const int freq, const bool sameLength, - int* nextLetters, const int nextLettersSize) { + const int inputIndex, const int matchWeight, const int freq, const bool sameLength, + int* nextLetters, const int nextLettersSize, CorrectionState *correctionState) { + const int skipPos = correctionState->getSkipPos(); const bool isSameAsTyped = sameLength ? mProximityInfo->sameAsTyped(word, depth + 1) : false; if (isSameAsTyped) return; if (depth >= MIN_SUGGEST_DEPTH) { - const int finalFreq = calculateFinalFreq(inputIndex, depth, matchWeight, skipPos, - excessivePos, transposedPos, freq, sameLength); + const int finalFreq = calculateFinalFreq(inputIndex, depth, matchWeight, + freq, sameLength, correctionState); if (!isSameAsTyped) addWord(word, depth + 1, finalFreq); } @@ -648,48 +657,6 @@ bool UnigramDictionary::getSplitTwoWordsSuggestion(const int inputLength, } #ifndef NEW_DICTIONARY_FORMAT -// The following functions will be entirely replaced with new implementations. -void UnigramDictionary::getWordsOld(const int initialPos, const int inputLength, const int skipPos, - const int excessivePos, const int transposedPos,int *nextLetters, - const int nextLettersSize) { - int initialPosition = initialPos; - const int count = Dictionary::getCount(DICT_ROOT, &initialPosition); - getWordsRec(count, initialPosition, 0, - min(inputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH), - mInputLength <= 0, 1, 0, 0, skipPos, excessivePos, transposedPos, nextLetters, - nextLettersSize); -} - -void UnigramDictionary::getWordsRec(const int childrenCount, const int pos, const int depth, - const int maxDepth, const bool traverseAllNodes, const int matchWeight, - const int inputIndex, const int diffs, const int skipPos, const int excessivePos, - const int transposedPos, int *nextLetters, const int nextLettersSize) { - int siblingPos = pos; - for (int i = 0; i < childrenCount; ++i) { - int newCount; - int newChildPosition; - bool newTraverseAllNodes; - int newMatchRate; - int newInputIndex; - int newDiffs; - int newSiblingPos; - int newOutputIndex; - const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, depth, maxDepth, - traverseAllNodes, matchWeight, inputIndex, diffs, - skipPos, excessivePos, transposedPos, - nextLetters, nextLettersSize, - &newCount, &newChildPosition, &newTraverseAllNodes, &newMatchRate, - &newInputIndex, &newDiffs, &newSiblingPos, &newOutputIndex); - siblingPos = newSiblingPos; - - if (needsToTraverseChildrenNodes) { - getWordsRec(newCount, newChildPosition, newOutputIndex, maxDepth, newTraverseAllNodes, - newMatchRate, newInputIndex, newDiffs, skipPos, excessivePos, transposedPos, - nextLetters, nextLettersSize); - } - } -} - inline int UnigramDictionary::getMostFrequentWordLike(const int startInputIndex, const int inputLength, unsigned short *word) { int pos = ROOT_POS; @@ -829,10 +796,13 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs // The following functions will be modified. inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth, const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex, - const int initialDiffs, const int skipPos, const int excessivePos, const int transposedPos, - int *nextLetters, const int nextLettersSize, int *newCount, int *newChildPosition, + const int initialDiffs, int *nextLetters, const int nextLettersSize, + CorrectionState *correctionState, int *newCount, int *newChildPosition, bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs, int *nextSiblingPosition, int *nextOutputIndex) { + const int skipPos = correctionState->getSkipPos(); + const int excessivePos = correctionState->getExcessivePos(); + const int transposedPos = correctionState->getTransposedPos(); if (DEBUG_DICT) { int inputCount = 0; if (skipPos >= 0) ++inputCount; @@ -865,8 +835,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, depth)) { mWord[depth] = c; if (traverseAllNodes && terminal) { - onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos, - excessivePos, transposedPos, freq, false, nextLetters, nextLettersSize); + onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, + freq, false, nextLetters, nextLettersSize, mCorrectionState); } if (!needsToTraverseChildrenNodes) return false; *newTraverseAllNodes = traverseAllNodes; @@ -882,7 +852,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in } ProximityInfo::ProximityType matchedProximityCharId = mProximityInfo->getMatchedProximityId( - inputIndexForProximity, c, skipPos, excessivePos, transposedPos); + inputIndexForProximity, c, mCorrectionState); if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) return false; mWord[depth] = c; // If inputIndex is greater than mInputLength, that means there is no @@ -893,8 +863,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in bool isSameAsUserTypedLength = mInputLength == inputIndex + 1 || (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2); if (isSameAsUserTypedLength && terminal) { - onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos, - excessivePos, transposedPos, freq, true, nextLetters, nextLettersSize); + onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, + freq, true, nextLetters, nextLettersSize, mCorrectionState); } if (!needsToTraverseChildrenNodes) return false; // Start traversing all nodes after the index exceeds the user typed length @@ -1081,16 +1051,15 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs // given level, as output into newCount when traversing this level's parent. inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth, const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex, - const int initialDiffs, const int skipPos, const int excessivePos, const int transposedPos, - int *nextLetters, const int nextLettersSize, int *newCount, int *newChildrenPosition, + const int initialDiffs, int *nextLetters, const int nextLettersSize, + CorrectionState *correctionState, int *newCount, int *newChildrenPosition, bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs, int *nextSiblingPosition, int *newOutputIndex) { + const int skipPos = correctionState->getSkipPos(); + const int excessivePos = correctionState->getExcessivePos(); + const int transposedPos = correctionState->getTransposedPos(); if (DEBUG_DICT) { - int inputCount = 0; - if (skipPos >= 0) ++inputCount; - if (excessivePos >= 0) ++inputCount; - if (transposedPos >= 0) ++inputCount; - assert(inputCount <= 1); + correctionState->checkState(); } int pos = initialPos; int depth = initialDepth; @@ -1146,8 +1115,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // The frequency should be here, because we come here only if this is actually // a terminal node, and we are on its last char. const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos, - excessivePos, transposedPos, freq, false, nextLetters, nextLettersSize); + onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, + freq, false, nextLetters, nextLettersSize, mCorrectionState); } if (!hasChildren) { // If we don't have children here, that means we finished processing all @@ -1170,7 +1139,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in } int matchedProximityCharId = mProximityInfo->getMatchedProximityId( - inputIndexForProximity, c, skipPos, excessivePos, transposedPos); + inputIndexForProximity, c, mCorrectionState); if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) { // We found that this is an unrelated character, so we should give up traversing // this node and its children entirely. @@ -1197,8 +1166,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in || (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2); if (isSameAsUserTypedLength && isTerminal) { const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos, - excessivePos, transposedPos, freq, true, nextLetters, nextLettersSize); + onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, + freq, true, nextLetters, nextLettersSize, mCorrectionState); } // This character matched the typed character (enough to traverse the node at least) // so we just evaluated it. Now we should evaluate this virtual node's children - that diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h index 97198ef13..e8a0868ac 100644 --- a/native/src/unigram_dictionary.h +++ b/native/src/unigram_dictionary.h @@ -18,6 +18,7 @@ #define LATINIME_UNIGRAM_DICTIONARY_H #include +#include "correction_state.h" #include "defines.h" #include "proximity_info.h" @@ -76,7 +77,7 @@ public: int getSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, const int *ycoordinates, const int *codes, const int codesSize, const int flags, unsigned short *outWords, int *frequencies); - ~UnigramDictionary(); + virtual ~UnigramDictionary(); private: void getWordSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, @@ -99,34 +100,24 @@ private: const int secondWordStartPos, const int secondWordLength, const bool isSpaceProximity); bool getMissingSpaceWords(const int inputLength, const int missingSpacePos); bool getMistypedSpaceWords(const int inputLength, const int spaceProximityPos); - int calculateFinalFreq(const int inputIndex, const int depth, const int snr, const int skipPos, - const int excessivePos, const int transposedPos, const int freq, - const bool sameLength) const; + int calculateFinalFreq(const int inputIndex, const int depth, const int snr, + const int freq, const bool sameLength, CorrectionState *correctionState) const; void onTerminal(unsigned short int* word, const int depth, const uint8_t* const root, const uint8_t flags, const int pos, - const int inputIndex, const int matchWeight, const int skipPos, - const int excessivePos, const int transposedPos, const int freq, const bool sameLength, - int *nextLetters, const int nextLettersSize); + const int inputIndex, const int matchWeight, const int freq, const bool sameLength, + int* nextLetters, const int nextLettersSize, CorrectionState *correctionState); bool needsToSkipCurrentNode(const unsigned short c, const int inputIndex, const int skipPos, const int depth); // Process a node by considering proximity, missing and excessive character bool processCurrentNode(const int initialPos, const int initialDepth, - const int maxDepth, const bool initialTraverseAllNodes, const int snr, int inputIndex, - const int initialDiffs, const int skipPos, const int excessivePos, - const int transposedPos, int *nextLetters, const int nextLettersSize, int *newCount, - int *newChildPosition, bool *newTraverseAllNodes, int *newSnr, int*newInputIndex, - int *newDiffs, int *nextSiblingPosition, int *nextOutputIndex); + const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex, + const int initialDiffs, int *nextLetters, const int nextLettersSize, + CorrectionState *correctionState, int *newCount, int *newChildPosition, + bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs, + int *nextSiblingPosition, int *nextOutputIndex); int getMostFrequentWordLike(const int startInputIndex, const int inputLength, unsigned short *word); #ifndef NEW_DICTIONARY_FORMAT - void getWordsRec(const int childrenCount, const int pos, const int depth, const int maxDepth, - const bool traverseAllNodes, const int snr, const int inputIndex, const int diffs, - const int skipPos, const int excessivePos, const int transposedPos, int *nextLetters, - const int nextLettersSize); - // Keep getWordsOld for comparing performance between getWords and getWordsOld - void getWordsOld(const int initialPos, const int inputLength, const int skipPos, - const int excessivePos, const int transposedPos, int *nextLetters, - const int nextLettersSize); // Process a node by considering missing space bool processCurrentNodeForExactMatch(const int firstChildPos, const int startInputIndex, const int depth, unsigned short *word, @@ -158,7 +149,8 @@ private: int *mFrequencies; unsigned short *mOutputChars; - const ProximityInfo *mProximityInfo; + ProximityInfo *mProximityInfo; + CorrectionState *mCorrectionState; int mInputLength; // MAX_WORD_LENGTH_INTERNAL must be bigger than MAX_WORD_LENGTH unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];