am 3261746f: Merge "Fix exact match checking for words with digraph."

* commit '3261746ff99dbf37d556b9c0e82192068437f99c':
  Fix exact match checking for words with digraph.
This commit is contained in:
Keisuke Kuroynagi 2013-04-25 02:59:48 -07:00 committed by Android Git Automerger
commit f2edc2a81d
9 changed files with 115 additions and 102 deletions

View file

@ -438,4 +438,24 @@ typedef enum {
// Create new word with space substitution // Create new word with space substitution
CT_NEW_WORD_SPACE_SUBSTITUTION, CT_NEW_WORD_SPACE_SUBSTITUTION,
} CorrectionType; } CorrectionType;
// ErrorType classifies a single correction step for scoring purposes.
// It is mainly decided by the CorrectionType, but it also depends on whether
// the correction has actually been performed or not.
typedef enum {
// Edit correction: substitution, omission, insertion or transposition
ET_EDIT_CORRECTION,
// Proximity error
ET_PROXIMITY_CORRECTION,
// Completion
ET_COMPLETION,
// New word
// TODO: Remove.
// A new word error should be an edit correction error or a proximity correction error.
ET_NEW_WORD,
// Treat the error as an intentional omission when the CorrectionType is omission and the
// node can be an intentional omission.
ET_INTENTIONAL_OMISSION,
// Not treated as an error. Tracked so that exact matches can be detected.
ET_NOT_AN_ERROR
} ErrorType;
#endif // LATINIME_DEFINES_H #endif // LATINIME_DEFINES_H

View file

@ -463,6 +463,10 @@ class DicNode {
mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
} }
// Reports whether this node is still an exact match, as tracked by the node's scoring state.
bool isExactMatch() const {
    const bool exactMatch = mDicNodeState.mDicNodeStateScoring.isExactMatch();
    return exactMatch;
}
uint8_t getFlags() const { uint8_t getFlags() const {
return mDicNodeProperties.getFlags(); return mDicNodeProperties.getFlags();
} }
@ -542,13 +546,12 @@ class DicNode {
// Caveat: Must not be called outside Weighting // Caveat: Must not be called outside Weighting
// This restriction is guaranteed by "friend" // This restriction is guaranteed by "friend"
AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
const bool doNormalization, const int inputSize, const bool isEditCorrection, const bool doNormalization, const int inputSize, const ErrorType errorType) {
const bool isProximityCorrection) {
if (DEBUG_GEO_FULL) { if (DEBUG_GEO_FULL) {
LOGI_SHOW_ADD_COST_PROP; LOGI_SHOW_ADD_COST_PROP;
} }
mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection); inputSize, getTotalInputIndex(), errorType);
} }
// Caveat: Must not be called outside Weighting // Caveat: Must not be called outside Weighting

View file

@ -46,8 +46,8 @@ class DicNodeStateInput {
for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
mInputIndex[i] = src->mInputIndex[i]; mInputIndex[i] = src->mInputIndex[i];
mPrevCodePoint[i] = src->mPrevCodePoint[i]; mPrevCodePoint[i] = src->mPrevCodePoint[i];
mTerminalDiffCost[i] = resetTerminalDiffCost ? mTerminalDiffCost[i] = resetTerminalDiffCost ?
static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i]; static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i];
} }
} }

View file

@ -31,7 +31,7 @@ class DicNodeStateScoring {
mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
mEditCorrectionCount(0), mProximityCorrectionCount(0), mEditCorrectionCount(0), mProximityCorrectionCount(0),
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
mRawLength(0.0f) { mRawLength(0.0f), mExactMatch(true) {
} }
virtual ~DicNodeStateScoring() {} virtual ~DicNodeStateScoring() {}
@ -45,6 +45,7 @@ class DicNodeStateScoring {
mRawLength = 0.0f; mRawLength = 0.0f;
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
mExactMatch = true;
} }
AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) { AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
@ -56,17 +57,32 @@ class DicNodeStateScoring {
mRawLength = scoring->mRawLength; mRawLength = scoring->mRawLength;
mDoubleLetterLevel = scoring->mDoubleLetterLevel; mDoubleLetterLevel = scoring->mDoubleLetterLevel;
mDigraphIndex = scoring->mDigraphIndex; mDigraphIndex = scoring->mDigraphIndex;
mExactMatch = scoring->mExactMatch;
} }
void addCost(const float spatialCost, const float languageCost, const bool doNormalization, void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
const int inputSize, const int totalInputIndex, const bool isEditCorrection, const int inputSize, const int totalInputIndex, const ErrorType errorType) {
const bool isProximityCorrection) {
addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex); addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
if (isEditCorrection) { switch (errorType) {
++mEditCorrectionCount; case ET_EDIT_CORRECTION:
} ++mEditCorrectionCount;
if (isProximityCorrection) { mExactMatch = false;
++mProximityCorrectionCount; break;
case ET_PROXIMITY_CORRECTION:
++mProximityCorrectionCount;
mExactMatch = false;
break;
case ET_COMPLETION:
mExactMatch = false;
break;
case ET_NEW_WORD:
mExactMatch = false;
break;
case ET_INTENTIONAL_OMISSION:
mExactMatch = false;
break;
case ET_NOT_AN_ERROR:
break;
} }
} }
@ -143,6 +159,10 @@ class DicNodeStateScoring {
} }
} }
// Returns the exact-match flag. The flag starts out true and is cleared by addCost() for
// every ErrorType other than ET_NOT_AN_ERROR.
bool isExactMatch() const {
    const bool exactSoFar = mExactMatch;
    return exactSoFar;
}
private: private:
// Caution!!! // Caution!!!
// Use a default copy constructor and an assign operator because shallow copies are ok // Use a default copy constructor and an assign operator because shallow copies are ok
@ -157,6 +177,7 @@ class DicNodeStateScoring {
float mSpatialDistance; float mSpatialDistance;
float mLanguageDistance; float mLanguageDistance;
float mRawLength; float mRawLength;
bool mExactMatch;
AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
bool doNormalization, int inputSize, int totalInputIndex) { bool doNormalization, int inputSize, int totalInputIndex) {

View file

@ -80,9 +80,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
traverseSession, parentDicNode, dicNode, &inputStateG); traverseSession, parentDicNode, dicNode, &inputStateG);
const float languageCost = Weighting::getLanguageCost(weighting, correctionType, const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
traverseSession, parentDicNode, dicNode, bigramCacheMap); traverseSession, parentDicNode, dicNode, bigramCacheMap);
const bool edit = Weighting::isEditCorrection(correctionType); const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession,
const bool proximity = Weighting::isProximityCorrection(weighting, correctionType, parentDicNode, dicNode);
traverseSession, dicNode);
profile(correctionType, dicNode); profile(correctionType, dicNode);
if (inputStateG.mNeedsToUpdateInputStateG) { if (inputStateG.mNeedsToUpdateInputStateG) {
dicNode->updateInputIndexG(&inputStateG); dicNode->updateInputIndexG(&inputStateG);
@ -91,7 +90,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
(correctionType == CT_TRANSPOSITION)); (correctionType == CT_TRANSPOSITION));
} }
dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(), dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
inputSize, edit, proximity); inputSize, errorType);
} }
/* static */ float Weighting::getSpatialCost(const Weighting *const weighting, /* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
@ -158,62 +157,6 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
} }
} }
// Returns true when the given correction type is counted as an edit correction:
// omission, additional proximity, substitution, insertion or transposition.
/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) {
    switch (correctionType) {
        // Corrections counted as edits.
        case CT_OMISSION:
        case CT_ADDITIONAL_PROXIMITY:
        case CT_SUBSTITUTION:
        case CT_INSERTION:
        case CT_TRANSPOSITION:
            return true;
        // Everything else, including unknown values, is not an edit.
        case CT_NEW_WORD_SPACE_OMITTION:
        case CT_MATCH:
        case CT_COMPLETION:
        case CT_TERMINAL:
        case CT_NEW_WORD_SPACE_SUBSTITUTION:
        default:
            return false;
    }
}
// Returns true when the given correction step is counted as a proximity correction.
// CT_ADDITIONAL_PROXIMITY always counts; CT_MATCH counts only when the weighting reports
// the node as a proximity node; every other correction type does not count.
/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting,
        const CorrectionType correctionType,
        const DicTraverseSession *const traverseSession, const DicNode *const dicNode) {
    if (correctionType == CT_ADDITIONAL_PROXIMITY) {
        return true;
    }
    if (correctionType == CT_MATCH) {
        // Only a match needs to consult the weighting implementation.
        return weighting->isProximityDicNode(traverseSession, dicNode);
    }
    return false;
}
/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) { /* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
switch(correctionType) { switch(correctionType) {
case CT_OMISSION: case CT_OMISSION:

View file

@ -80,6 +80,10 @@ class Weighting {
virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0; const DicNode *const dicNode) const = 0;
virtual ErrorType getErrorType(const CorrectionType correctionType,
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
Weighting() {} Weighting() {}
virtual ~Weighting() {} virtual ~Weighting() {}
@ -95,12 +99,6 @@ class Weighting {
const DicNode *const parentDicNode, const DicNode *const dicNode, const DicNode *const parentDicNode, const DicNode *const dicNode,
hash_map_compat<int, int16_t> *const bigramCacheMap); hash_map_compat<int, int16_t> *const bigramCacheMap);
// TODO: Move to TypingWeighting and GestureWeighting? // TODO: Move to TypingWeighting and GestureWeighting?
static bool isEditCorrection(const CorrectionType correctionType);
// TODO: Move to TypingWeighting and GestureWeighting?
static bool isProximityCorrection(const Weighting *const weighting,
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const dicNode);
// TODO: Move to TypingWeighting and GestureWeighting?
static int getForwardInputCount(const CorrectionType correctionType); static int getForwardInputCount(const CorrectionType correctionType);
}; };
} // namespace latinime } // namespace latinime

View file

@ -422,20 +422,15 @@ void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession,
*/ */
void Suggest::processDicNodeAsOmission( void Suggest::processDicNodeAsOmission(
DicTraverseSession *traverseSession, DicNode *dicNode) const { DicTraverseSession *traverseSession, DicNode *dicNode) const {
// If the omission is surely intentional that it should incur zero cost.
const bool isZeroCostOmission = dicNode->isZeroCostOmission();
DicNodeVector childDicNodes; DicNodeVector childDicNodes;
DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes); DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes);
const int size = childDicNodes.getSizeAndLock(); const int size = childDicNodes.getSizeAndLock();
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
DicNode *const childDicNode = childDicNodes[i]; DicNode *const childDicNode = childDicNodes[i];
if (!isZeroCostOmission) { // Treat this word as omission
// Treat this word as omission Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, dicNode, childDicNode, 0 /* bigramCacheMap */);
dicNode, childDicNode, 0 /* bigramCacheMap */);
}
weightChildNode(traverseSession, childDicNode); weightChildNode(traverseSession, childDicNode);
if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) { if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) {

View file

@ -21,4 +21,39 @@
namespace latinime { namespace latinime {
const TypingWeighting TypingWeighting::sInstance; const TypingWeighting TypingWeighting::sInstance;
// Maps the correction applied while moving from parentDicNode to dicNode onto an ErrorType.
//
// correctionType:  the correction that was applied for this step.
// traverseSession: the current traversal session; only used for the proximity check.
// parentDicNode:   the node the step started from; consulted for intentional omissions.
// dicNode:         the node the step arrived at; consulted for the proximity check.
//
// Returns ET_NOT_AN_ERROR for non-proximity matches, terminals and unknown correction types.
ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType,
        const DicTraverseSession *const traverseSession,
        const DicNode *const parentDicNode, const DicNode *const dicNode) const {
    switch (correctionType) {
        case CT_MATCH:
            // A match is only error-free when the node is not a proximity node.
            if (isProximityDicNode(traverseSession, dicNode)) {
                return ET_PROXIMITY_CORRECTION;
            }
            return ET_NOT_AN_ERROR;
        case CT_ADDITIONAL_PROXIMITY:
            return ET_PROXIMITY_CORRECTION;
        case CT_OMISSION:
            // An omission is an edit correction unless the parent node allows it to be
            // an intentional omission.
            if (parentDicNode->canBeIntentionalOmission()) {
                return ET_INTENTIONAL_OMISSION;
            }
            return ET_EDIT_CORRECTION;
        case CT_SUBSTITUTION:
        case CT_INSERTION:
        case CT_TRANSPOSITION:
            return ET_EDIT_CORRECTION;
        case CT_NEW_WORD_SPACE_OMITTION:
        case CT_NEW_WORD_SPACE_SUBSTITUTION:
            return ET_NEW_WORD;
        case CT_COMPLETION:
            return ET_COMPLETION;
        case CT_TERMINAL:
        default:
            return ET_NOT_AN_ERROR;
    }
}
} // namespace latinime } // namespace latinime

View file

@ -50,13 +50,14 @@ class TypingWeighting : public Weighting {
} }
float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
bool sameCodePoint = false; const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
bool isFirstLetterOmission = false; const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
float cost = 0.0f;
sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
// If the traversal omitted the first letter then the dicNode should now be on the second. // If the traversal omitted the first letter then the dicNode should now be on the second.
isFirstLetterOmission = dicNode->getDepth() == 2; const bool isFirstLetterOmission = dicNode->getDepth() == 2;
if (isFirstLetterOmission) { float cost = 0.0f;
if (isZeroCostOmission) {
cost = 0.0f;
} else if (isFirstLetterOmission) {
cost = ScoringParams::OMISSION_COST_FIRST_CHAR; cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
} else { } else {
cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
@ -156,15 +157,8 @@ class TypingWeighting : public Weighting {
float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
const bool hasEditCount = dicNode->getEditCorrectionCount() > 0; const float languageImprobability = (dicNode->isExactMatch()) ?
const bool isSameLength = dicNode->getDepth() == traverseSession->getInputSize(); 0.0f : dicNodeLanguageImprobability;
const bool hasMultipleWords = dicNode->hasMultipleWords();
const bool hasProximityErrors = dicNode->getProximityCorrectionCount() > 0;
// Gesture input is always assumed to have proximity errors
// because the input word shouldn't be treated as perfect
const bool isExactMatch = !hasEditCount && !hasMultipleWords
&& !hasProximityErrors && isSameLength;
const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability;
return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
} }
@ -189,6 +183,10 @@ class TypingWeighting : public Weighting {
return cost * traverseSession->getMultiWordCostMultiplier(); return cost * traverseSession->getMultiWordCostMultiplier();
} }
ErrorType getErrorType(const CorrectionType correctionType,
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const;
private: private:
DISALLOW_COPY_AND_ASSIGN(TypingWeighting); DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
static const TypingWeighting sInstance; static const TypingWeighting sInstance;