Add correction state

Change-Id: I0d281cede1590893bd1def005cf83c9431d12750
This commit is contained in:
satok 2011-07-15 13:49:00 +09:00
parent e00d44d0c8
commit 2df3060883
7 changed files with 166 additions and 97 deletions

View file

@ -17,6 +17,7 @@ LOCAL_SRC_FILES := \
jni/jni_common.cpp \
src/bigram_dictionary.cpp \
src/char_utils.cpp \
src/correction_state.cpp \
src/dictionary.cpp \
src/proximity_info.cpp \
src/unigram_dictionary.cpp

View file

@ -0,0 +1,52 @@
/*
* Copyright (C) 2011 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include <stdio.h>
#include <string.h>
#define LOG_TAG "LatinIME: correction_state.cpp"
#include "correction_state.h"
namespace latinime {
CorrectionState::CorrectionState() {
}
void CorrectionState::setCorrectionParams(const ProximityInfo *pi, const int inputLength,
const int skipPos, const int excessivePos, const int transposedPos) {
mProximityInfo = pi;
mSkipPos = skipPos;
mExcessivePos = excessivePos;
mTransposedPos = transposedPos;
}
void CorrectionState::checkState() {
if (DEBUG_DICT) {
int inputCount = 0;
if (mSkipPos >= 0) ++inputCount;
if (mExcessivePos >= 0) ++inputCount;
if (mTransposedPos >= 0) ++inputCount;
// TODO: remove this assert
assert(inputCount <= 1);
}
}
CorrectionState::~CorrectionState() {
}
} // namespace latinime

View file

@ -0,0 +1,52 @@
/*
* Copyright (C) 2011 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_CORRECTION_STATE_H
#define LATINIME_CORRECTION_STATE_H
#include <stdint.h>
#include "defines.h"
namespace latinime {
class ProximityInfo;
class CorrectionState {
public:
CorrectionState();
void setCorrectionParams(const ProximityInfo *pi, const int inputLength, const int skipPos,
const int excessivePos, const int transposedPos);
void checkState();
virtual ~CorrectionState();
int getSkipPos() const {
return mSkipPos;
}
int getExcessivePos() const {
return mExcessivePos;
}
int getTransposedPos() const {
return mTransposedPos;
}
private:
const ProximityInfo *mProximityInfo;
int mInputLength;
int mSkipPos;
int mExcessivePos;
int mTransposedPos;
};
} // namespace latinime
#endif // LATINIME_CORRECTION_INFO_H

View file

@ -78,7 +78,7 @@ unsigned short ProximityInfo::getPrimaryCharAt(const int index) const {
return getProximityCharsAt(index)[0];
}
bool ProximityInfo::existsCharInProximityAt(const int index, const int c) const {
inline bool ProximityInfo::existsCharInProximityAt(const int index, const int c) const {
const int *chars = getProximityCharsAt(index);
int i = 0;
while (chars[i] > 0 && i < MAX_PROXIMITY_CHARS_SIZE) {
@ -114,8 +114,10 @@ bool ProximityInfo::existsAdjacentProximityChars(const int index) const {
// in their list. The non-accented version of the character should be considered
// "close", but not the other keys close to the non-accented version.
ProximityInfo::ProximityType ProximityInfo::getMatchedProximityId(
const int index, const unsigned short c, const int skipPos,
const int excessivePos, const int transposedPos) const {
const int index, const unsigned short c, CorrectionState *correctionState) const {
const int skipPos = correctionState->getSkipPos();
const int excessivePos = correctionState->getExcessivePos();
const int transposedPos = correctionState->getTransposedPos();
const int *currentChars = getProximityCharsAt(index);
const unsigned short baseLowerC = Dictionary::toBaseLowerCase(c);

View file

@ -23,6 +23,8 @@
namespace latinime {
class CorrectionState;
class ProximityInfo {
public:
typedef enum { // Used as a return value for character comparison
@ -42,8 +44,7 @@ public:
bool existsCharInProximityAt(const int index, const int c) const;
bool existsAdjacentProximityChars(const int index) const;
ProximityType getMatchedProximityId(
const int index, const unsigned short c, const int skipPos,
const int excessivePos, const int transposedPos) const;
const int index, const unsigned short c, CorrectionState *correctionState) const;
bool sameAsTyped(const unsigned short *word, int length) const;
private:
int getStartIndexFromCoordinates(const int x, const int y) const;

View file

@ -58,9 +58,12 @@ UnigramDictionary::UnigramDictionary(const uint8_t* const streamStart, int typed
if (DEBUG_DICT) {
LOGI("UnigramDictionary - constructor");
}
mCorrectionState = new CorrectionState();
}
UnigramDictionary::~UnigramDictionary() {}
UnigramDictionary::~UnigramDictionary() {
delete mCorrectionState;
}
static inline unsigned int getCodesBufferSize(const int* codes, const int codesSize,
const int MAX_PROXIMITY_CHARS) {
@ -362,6 +365,8 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos,
assert(excessivePos < mInputLength);
assert(missingPos < mInputLength);
}
mCorrectionState->setCorrectionParams(mProximityInfo, mInputLength, skipPos, excessivePos,
transposedPos);
int rootPosition = ROOT_POS;
// Get the number of children of root, then increment the position
int childCount = Dictionary::getCount(DICT_ROOT, &rootPosition);
@ -389,8 +394,8 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos,
// depth will never be greater than maxDepth because in that case,
// needsToTraverseChildrenNodes should be false
const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex,
maxDepth, traverseAllNodes, matchWeight, inputIndex, diffs, skipPos,
excessivePos, transposedPos, nextLetters, nextLettersSize, &childCount,
maxDepth, traverseAllNodes, matchWeight, inputIndex, diffs,
nextLetters, nextLettersSize, mCorrectionState, &childCount,
&firstChildPos, &traverseAllNodes, &matchWeight, &inputIndex, &diffs,
&siblingPos, &outputIndex);
// Update next sibling pos
@ -521,8 +526,12 @@ bool UnigramDictionary::getMistypedSpaceWords(const int inputLength, const int s
}
inline int UnigramDictionary::calculateFinalFreq(const int inputIndex, const int depth,
const int matchWeight, const int skipPos, const int excessivePos, const int transposedPos,
const int freq, const bool sameLength) const {
const int matchWeight, const int freq, const bool sameLength,
CorrectionState *correctionState) const {
const int skipPos = correctionState->getSkipPos();
const int excessivePos = correctionState->getExcessivePos();
const int transposedPos = correctionState->getTransposedPos();
// TODO: Demote by edit distance
int finalFreq = freq * matchWeight;
if (skipPos >= 0) {
@ -587,16 +596,16 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c,
inline void UnigramDictionary::onTerminal(unsigned short int* word, const int depth,
const uint8_t* const root, const uint8_t flags, const int pos,
const int inputIndex, const int matchWeight, const int skipPos,
const int excessivePos, const int transposedPos, const int freq, const bool sameLength,
int* nextLetters, const int nextLettersSize) {
const int inputIndex, const int matchWeight, const int freq, const bool sameLength,
int* nextLetters, const int nextLettersSize, CorrectionState *correctionState) {
const int skipPos = correctionState->getSkipPos();
const bool isSameAsTyped = sameLength ? mProximityInfo->sameAsTyped(word, depth + 1) : false;
if (isSameAsTyped) return;
if (depth >= MIN_SUGGEST_DEPTH) {
const int finalFreq = calculateFinalFreq(inputIndex, depth, matchWeight, skipPos,
excessivePos, transposedPos, freq, sameLength);
const int finalFreq = calculateFinalFreq(inputIndex, depth, matchWeight,
freq, sameLength, correctionState);
if (!isSameAsTyped)
addWord(word, depth + 1, finalFreq);
}
@ -648,48 +657,6 @@ bool UnigramDictionary::getSplitTwoWordsSuggestion(const int inputLength,
}
#ifndef NEW_DICTIONARY_FORMAT
// The following functions will be entirely replaced with new implementations.
void UnigramDictionary::getWordsOld(const int initialPos, const int inputLength, const int skipPos,
const int excessivePos, const int transposedPos,int *nextLetters,
const int nextLettersSize) {
int initialPosition = initialPos;
const int count = Dictionary::getCount(DICT_ROOT, &initialPosition);
getWordsRec(count, initialPosition, 0,
min(inputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH),
mInputLength <= 0, 1, 0, 0, skipPos, excessivePos, transposedPos, nextLetters,
nextLettersSize);
}
void UnigramDictionary::getWordsRec(const int childrenCount, const int pos, const int depth,
const int maxDepth, const bool traverseAllNodes, const int matchWeight,
const int inputIndex, const int diffs, const int skipPos, const int excessivePos,
const int transposedPos, int *nextLetters, const int nextLettersSize) {
int siblingPos = pos;
for (int i = 0; i < childrenCount; ++i) {
int newCount;
int newChildPosition;
bool newTraverseAllNodes;
int newMatchRate;
int newInputIndex;
int newDiffs;
int newSiblingPos;
int newOutputIndex;
const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, depth, maxDepth,
traverseAllNodes, matchWeight, inputIndex, diffs,
skipPos, excessivePos, transposedPos,
nextLetters, nextLettersSize,
&newCount, &newChildPosition, &newTraverseAllNodes, &newMatchRate,
&newInputIndex, &newDiffs, &newSiblingPos, &newOutputIndex);
siblingPos = newSiblingPos;
if (needsToTraverseChildrenNodes) {
getWordsRec(newCount, newChildPosition, newOutputIndex, maxDepth, newTraverseAllNodes,
newMatchRate, newInputIndex, newDiffs, skipPos, excessivePos, transposedPos,
nextLetters, nextLettersSize);
}
}
}
inline int UnigramDictionary::getMostFrequentWordLike(const int startInputIndex,
const int inputLength, unsigned short *word) {
int pos = ROOT_POS;
@ -829,10 +796,13 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs
// The following functions will be modified.
inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth,
const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex,
const int initialDiffs, const int skipPos, const int excessivePos, const int transposedPos,
int *nextLetters, const int nextLettersSize, int *newCount, int *newChildPosition,
const int initialDiffs, int *nextLetters, const int nextLettersSize,
CorrectionState *correctionState, int *newCount, int *newChildPosition,
bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs,
int *nextSiblingPosition, int *nextOutputIndex) {
const int skipPos = correctionState->getSkipPos();
const int excessivePos = correctionState->getExcessivePos();
const int transposedPos = correctionState->getTransposedPos();
if (DEBUG_DICT) {
int inputCount = 0;
if (skipPos >= 0) ++inputCount;
@ -865,8 +835,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, depth)) {
mWord[depth] = c;
if (traverseAllNodes && terminal) {
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos,
excessivePos, transposedPos, freq, false, nextLetters, nextLettersSize);
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight,
freq, false, nextLetters, nextLettersSize, mCorrectionState);
}
if (!needsToTraverseChildrenNodes) return false;
*newTraverseAllNodes = traverseAllNodes;
@ -882,7 +852,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
}
ProximityInfo::ProximityType matchedProximityCharId = mProximityInfo->getMatchedProximityId(
inputIndexForProximity, c, skipPos, excessivePos, transposedPos);
inputIndexForProximity, c, mCorrectionState);
if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) return false;
mWord[depth] = c;
// If inputIndex is greater than mInputLength, that means there is no
@ -893,8 +863,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
bool isSameAsUserTypedLength = mInputLength == inputIndex + 1
|| (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2);
if (isSameAsUserTypedLength && terminal) {
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos,
excessivePos, transposedPos, freq, true, nextLetters, nextLettersSize);
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight,
freq, true, nextLetters, nextLettersSize, mCorrectionState);
}
if (!needsToTraverseChildrenNodes) return false;
// Start traversing all nodes after the index exceeds the user typed length
@ -1081,16 +1051,15 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs
// given level, as output into newCount when traversing this level's parent.
inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth,
const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex,
const int initialDiffs, const int skipPos, const int excessivePos, const int transposedPos,
int *nextLetters, const int nextLettersSize, int *newCount, int *newChildrenPosition,
const int initialDiffs, int *nextLetters, const int nextLettersSize,
CorrectionState *correctionState, int *newCount, int *newChildrenPosition,
bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs,
int *nextSiblingPosition, int *newOutputIndex) {
const int skipPos = correctionState->getSkipPos();
const int excessivePos = correctionState->getExcessivePos();
const int transposedPos = correctionState->getTransposedPos();
if (DEBUG_DICT) {
int inputCount = 0;
if (skipPos >= 0) ++inputCount;
if (excessivePos >= 0) ++inputCount;
if (transposedPos >= 0) ++inputCount;
assert(inputCount <= 1);
correctionState->checkState();
}
int pos = initialPos;
int depth = initialDepth;
@ -1146,8 +1115,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
// The frequency should be here, because we come here only if this is actually
// a terminal node, and we are on its last char.
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos,
excessivePos, transposedPos, freq, false, nextLetters, nextLettersSize);
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight,
freq, false, nextLetters, nextLettersSize, mCorrectionState);
}
if (!hasChildren) {
// If we don't have children here, that means we finished processing all
@ -1170,7 +1139,7 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
}
int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
inputIndexForProximity, c, skipPos, excessivePos, transposedPos);
inputIndexForProximity, c, mCorrectionState);
if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
// We found that this is an unrelated character, so we should give up traversing
// this node and its children entirely.
@ -1197,8 +1166,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
|| (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2);
if (isSameAsUserTypedLength && isTerminal) {
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos,
excessivePos, transposedPos, freq, true, nextLetters, nextLettersSize);
onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight,
freq, true, nextLetters, nextLettersSize, mCorrectionState);
}
// This character matched the typed character (enough to traverse the node at least)
// so we just evaluated it. Now we should evaluate this virtual node's children - that

View file

@ -18,6 +18,7 @@
#define LATINIME_UNIGRAM_DICTIONARY_H
#include <stdint.h>
#include "correction_state.h"
#include "defines.h"
#include "proximity_info.h"
@ -76,7 +77,7 @@ public:
int getSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates,
const int *ycoordinates, const int *codes, const int codesSize, const int flags,
unsigned short *outWords, int *frequencies);
~UnigramDictionary();
virtual ~UnigramDictionary();
private:
void getWordSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates,
@ -99,34 +100,24 @@ private:
const int secondWordStartPos, const int secondWordLength, const bool isSpaceProximity);
bool getMissingSpaceWords(const int inputLength, const int missingSpacePos);
bool getMistypedSpaceWords(const int inputLength, const int spaceProximityPos);
int calculateFinalFreq(const int inputIndex, const int depth, const int snr, const int skipPos,
const int excessivePos, const int transposedPos, const int freq,
const bool sameLength) const;
int calculateFinalFreq(const int inputIndex, const int depth, const int snr,
const int freq, const bool sameLength, CorrectionState *correctionState) const;
void onTerminal(unsigned short int* word, const int depth,
const uint8_t* const root, const uint8_t flags, const int pos,
const int inputIndex, const int matchWeight, const int skipPos,
const int excessivePos, const int transposedPos, const int freq, const bool sameLength,
int *nextLetters, const int nextLettersSize);
const int inputIndex, const int matchWeight, const int freq, const bool sameLength,
int* nextLetters, const int nextLettersSize, CorrectionState *correctionState);
bool needsToSkipCurrentNode(const unsigned short c,
const int inputIndex, const int skipPos, const int depth);
// Process a node by considering proximity, missing and excessive character
bool processCurrentNode(const int initialPos, const int initialDepth,
const int maxDepth, const bool initialTraverseAllNodes, const int snr, int inputIndex,
const int initialDiffs, const int skipPos, const int excessivePos,
const int transposedPos, int *nextLetters, const int nextLettersSize, int *newCount,
int *newChildPosition, bool *newTraverseAllNodes, int *newSnr, int*newInputIndex,
int *newDiffs, int *nextSiblingPosition, int *nextOutputIndex);
const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex,
const int initialDiffs, int *nextLetters, const int nextLettersSize,
CorrectionState *correctionState, int *newCount, int *newChildPosition,
bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs,
int *nextSiblingPosition, int *nextOutputIndex);
int getMostFrequentWordLike(const int startInputIndex, const int inputLength,
unsigned short *word);
#ifndef NEW_DICTIONARY_FORMAT
void getWordsRec(const int childrenCount, const int pos, const int depth, const int maxDepth,
const bool traverseAllNodes, const int snr, const int inputIndex, const int diffs,
const int skipPos, const int excessivePos, const int transposedPos, int *nextLetters,
const int nextLettersSize);
// Keep getWordsOld for comparing performance between getWords and getWordsOld
void getWordsOld(const int initialPos, const int inputLength, const int skipPos,
const int excessivePos, const int transposedPos, int *nextLetters,
const int nextLettersSize);
// Process a node by considering missing space
bool processCurrentNodeForExactMatch(const int firstChildPos,
const int startInputIndex, const int depth, unsigned short *word,
@ -158,7 +149,8 @@ private:
int *mFrequencies;
unsigned short *mOutputChars;
const ProximityInfo *mProximityInfo;
ProximityInfo *mProximityInfo;
CorrectionState *mCorrectionState;
int mInputLength;
// MAX_WORD_LENGTH_INTERNAL must be bigger than MAX_WORD_LENGTH
unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];