diff --git a/native/src/correction.cpp b/native/src/correction.cpp index 6d682c0c9..a931a61fb 100644 --- a/native/src/correction.cpp +++ b/native/src/correction.cpp @@ -52,6 +52,11 @@ void Correction::initCorrection(const ProximityInfo *pi, const int inputLength, mSkippedOutputIndex = -1; } +void Correction::initCorrectionState( + const int rootPos, const int childCount, const bool traverseAll) { + mCorrectionStates[0].init(rootPos, childCount, traverseAll); +} + void Correction::setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, const int spaceProximityPos, const int missingSpacePos) { mSkipPos = skipPos; @@ -90,22 +95,25 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen inputIndex, outputIndex, mMatchedCharCount, freq, sameLength, this); } -void Correction::initProcessState(const int matchCount, const int inputIndex, - const int outputIndex, const bool traverseAllNodes, const int diffs) { - mMatchedCharCount = matchCount; - mInputIndex = inputIndex; +bool Correction::initProcessState(const int outputIndex) { + if (mCorrectionStates[outputIndex].mChildCount <= 0) { + return false; + } mOutputIndex = outputIndex; - mTraverseAllNodes = traverseAllNodes; - mDiffs = diffs; + --(mCorrectionStates[outputIndex].mChildCount); + mMatchedCharCount = mCorrectionStates[outputIndex].mMatchedCount; + mInputIndex = mCorrectionStates[outputIndex].mInputIndex; + mTraverseAllNodes = mCorrectionStates[outputIndex].mTraverseAll; + mDiffs = mCorrectionStates[outputIndex].mDiffs; + return true; } -void Correction::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex, - bool *traverseAllNodes, int *diffs) { - *matchedCount = mMatchedCharCount; - *inputIndex = mInputIndex; - *outputIndex = mOutputIndex; - *traverseAllNodes = mTraverseAllNodes; - *diffs = mDiffs; +int Correction::goDownTree( + const int parentIndex, const int childCount, const int firstChildPos) { + mCorrectionStates[mOutputIndex].mParentIndex = parentIndex; + mCorrectionStates[mOutputIndex].mChildCount = childCount; + mCorrectionStates[mOutputIndex].mSiblingPos = firstChildPos; + return mOutputIndex; } void Correction::charMatched() { @@ -133,6 +141,13 @@ void Correction::incrementInputIndex() { void Correction::incrementOutputIndex() { ++mOutputIndex; + mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex; + mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount; + mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; + mCorrectionStates[mOutputIndex].mMatchedCount = mMatchedCharCount; + mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; + mCorrectionStates[mOutputIndex].mTraverseAll = mTraverseAllNodes; + mCorrectionStates[mOutputIndex].mDiffs = mDiffs; } void Correction::startTraverseAll() { diff --git a/native/src/correction.h b/native/src/correction.h index ae6c7a421..590d62f62 100644 --- a/native/src/correction.h +++ b/native/src/correction.h @@ -18,6 +18,7 @@ #define LATINIME_CORRECTION_H #include +#include "correction_state.h" #include "defines.h" @@ -39,11 +40,14 @@ public: Correction(const int typedLetterMultiplier, const int fullWordMultiplier); void initCorrection( const ProximityInfo *pi, const int inputLength, const int maxWordLength); + void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll); + + // TODO: remove void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, const int spaceProximityPos, const int missingSpacePos); void checkState(); - void initProcessState(const int matchCount, const int inputIndex, const int outputIndex, - const bool traverseAllNodes, const int diffs); + bool initProcessState(const int index); + void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex, bool *traverseAllNodes, int *diffs); int getOutputIndex(); @@ -80,6 +84,22 @@ public: int getDiffs() const { return mDiffs; } + + ///////////////////////// + // Tree helper methods + int goDownTree(const int parentIndex, const int childCount, const int firstChildPos); + + inline int getTreeSiblingPos(const int index) const { + return mCorrectionStates[index].mSiblingPos; + } + + inline void setTreeSiblingPos(const int index, const int pos) { + mCorrectionStates[index].mSiblingPos = pos; + } + + inline int getTreeParentIndex(const int index) const { + return mCorrectionStates[index].mParentIndex; + } private: void charMatched(); void incrementInputIndex(); @@ -116,6 +136,8 @@ private: bool mTraverseAllNodes; unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; + CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL]; + inline bool isQuote(const unsigned short c); inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal); diff --git a/native/src/correction_state.h b/native/src/correction_state.h new file mode 100644 index 000000000..1fe02b853 --- /dev/null +++ b/native/src/correction_state.h @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2011 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_CORRECTION_STATE_H +#define LATINIME_CORRECTION_STATE_H + +#include + +#include "defines.h" + +namespace latinime { + +class CorrectionState { +public: + int mParentIndex; + int mMatchedCount; + int mChildCount; + int mInputIndex; + int mDiffs; + int mSiblingPos; + bool mTraverseAll; + + inline void init(const int rootPos, const int childCount, const bool traverseAll) { + set(-1, 0, childCount, 0, 0, rootPos, traverseAll); + } + +private: + inline void set(const int parentIndex, const int matchedCount, const int childCount, + const int inputIndex, const int diffs, const int siblingPos, + const bool traverseAll) { + mParentIndex = parentIndex; + mMatchedCount = matchedCount; + mChildCount = childCount; + mInputIndex = inputIndex; + mDiffs = diffs; + mSiblingPos = siblingPos; + mTraverseAll = traverseAll; + } +}; +} // namespace latinime +#endif // LATINIME_CORRECTION_STATE_H diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index df1a2e273..bbfaea454 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -352,44 +352,28 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, int rootPosition = ROOT_POS; // Get the number of children of root, then increment the position int childCount = Dictionary::getCount(DICT_ROOT, &rootPosition); - int depth = 0; + int outputIndex = 0; - mStackChildCount[0] = childCount; - mStackTraverseAll[0] = (mInputLength <= 0); - mStackInputIndex[0] = 0; - mStackDiffs[0] = 0; - mStackSiblingPos[0] = rootPosition; - mStackOutputIndex[0] = 0; - mStackMatchedCount[0] = 0; + mCorrection->initCorrectionState(rootPosition, childCount, (mInputLength <= 0)); // Depth first search - while (depth >= 0) { - if (mStackChildCount[depth] > 0) { - --mStackChildCount[depth]; - int siblingPos = mStackSiblingPos[depth]; + while (outputIndex >= 0) { + if (mCorrection->initProcessState(outputIndex)) { + int siblingPos = mCorrection->getTreeSiblingPos(outputIndex); int firstChildPos; - mCorrection->initProcessState( - mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth], - mStackTraverseAll[depth], mStackDiffs[depth]); - // needsToTraverseChildrenNodes should be false const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, mCorrection, &childCount, &firstChildPos, &siblingPos); // Update next sibling pos - mStackSiblingPos[depth] = siblingPos; + mCorrection->setTreeSiblingPos(outputIndex, siblingPos); + if (needsToTraverseChildrenNodes) { // Goes to child node - ++depth; - mStackChildCount[depth] = childCount; - mStackSiblingPos[depth] = firstChildPos; - - mCorrection->getProcessState(&mStackMatchedCount[depth], - &mStackInputIndex[depth], &mStackOutputIndex[depth], - &mStackTraverseAll[depth], &mStackDiffs[depth]); + outputIndex = mCorrection->goDownTree(outputIndex, childCount, firstChildPos); } } else { // Goes to parent sibling node - --depth; + outputIndex = mCorrection->getTreeParentIndex(outputIndex); } } } diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h index 8bcd7cea5..cfe63ff79 100644 --- a/native/src/unigram_dictionary.h +++ b/native/src/unigram_dictionary.h @@ -19,6 +19,7 @@ #include #include "correction.h" +#include "correction_state.h" #include "defines.h" #include "proximity_info.h" @@ -134,13 +135,9 @@ private: // MAX_WORD_LENGTH_INTERNAL must be bigger than MAX_WORD_LENGTH unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; - int mStackMatchedCount[MAX_WORD_LENGTH_INTERNAL]; - int mStackChildCount[MAX_WORD_LENGTH_INTERNAL]; - bool mStackTraverseAll[MAX_WORD_LENGTH_INTERNAL]; - int mStackInputIndex[MAX_WORD_LENGTH_INTERNAL]; - int mStackDiffs[MAX_WORD_LENGTH_INTERNAL]; - int mStackSiblingPos[MAX_WORD_LENGTH_INTERNAL]; - int mStackOutputIndex[MAX_WORD_LENGTH_INTERNAL]; + int mStackChildCount[MAX_WORD_LENGTH_INTERNAL];// TODO: remove + int mStackInputIndex[MAX_WORD_LENGTH_INTERNAL];// TODO: remove + int mStackSiblingPos[MAX_WORD_LENGTH_INTERNAL];// TODO: remove }; } // namespace latinime