Merge "Fix exact match checking for words with digraph."
This commit is contained in:
commit
3261746ff9
9 changed files with 115 additions and 102 deletions
|
@ -438,4 +438,24 @@ typedef enum {
|
|||
// Create new word with space substitution
|
||||
CT_NEW_WORD_SPACE_SUBSTITUTION,
|
||||
} CorrectionType;
|
||||
|
||||
// ErrorType is mainly decided by CorrectionType but it is also depending on if
|
||||
// the correction has really been performed or not.
|
||||
typedef enum {
|
||||
// Substitution, omission and transposition
|
||||
ET_EDIT_CORRECTION,
|
||||
// Proximity error
|
||||
ET_PROXIMITY_CORRECTION,
|
||||
// Completion
|
||||
ET_COMPLETION,
|
||||
// New word
|
||||
// TODO: Remove.
|
||||
// A new word error should be an edit correction error or a proximity correction error.
|
||||
ET_NEW_WORD,
|
||||
// Treat error as an intentional omission when the CorrectionType is omission and the node can
|
||||
// be intentional omission.
|
||||
ET_INTENTIONAL_OMISSION,
|
||||
// Not treated as an error. Tracked for checking exact match
|
||||
ET_NOT_AN_ERROR
|
||||
} ErrorType;
|
||||
#endif // LATINIME_DEFINES_H
|
||||
|
|
|
@ -463,6 +463,10 @@ class DicNode {
|
|||
mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
|
||||
}
|
||||
|
||||
bool isExactMatch() const {
|
||||
return mDicNodeState.mDicNodeStateScoring.isExactMatch();
|
||||
}
|
||||
|
||||
uint8_t getFlags() const {
|
||||
return mDicNodeProperties.getFlags();
|
||||
}
|
||||
|
@ -542,13 +546,12 @@ class DicNode {
|
|||
// Caveat: Must not be called outside Weighting
|
||||
// This restriction is guaranteed by "friend"
|
||||
AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
|
||||
const bool doNormalization, const int inputSize, const bool isEditCorrection,
|
||||
const bool isProximityCorrection) {
|
||||
const bool doNormalization, const int inputSize, const ErrorType errorType) {
|
||||
if (DEBUG_GEO_FULL) {
|
||||
LOGI_SHOW_ADD_COST_PROP;
|
||||
}
|
||||
mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
|
||||
inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection);
|
||||
inputSize, getTotalInputIndex(), errorType);
|
||||
}
|
||||
|
||||
// Caveat: Must not be called outside Weighting
|
||||
|
|
|
@ -46,8 +46,8 @@ class DicNodeStateInput {
|
|||
for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
|
||||
mInputIndex[i] = src->mInputIndex[i];
|
||||
mPrevCodePoint[i] = src->mPrevCodePoint[i];
|
||||
mTerminalDiffCost[i] = resetTerminalDiffCost ?
|
||||
static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i];
|
||||
mTerminalDiffCost[i] = resetTerminalDiffCost ?
|
||||
static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ class DicNodeStateScoring {
|
|||
mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
|
||||
mEditCorrectionCount(0), mProximityCorrectionCount(0),
|
||||
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
|
||||
mRawLength(0.0f) {
|
||||
mRawLength(0.0f), mExactMatch(true) {
|
||||
}
|
||||
|
||||
virtual ~DicNodeStateScoring() {}
|
||||
|
@ -45,6 +45,7 @@ class DicNodeStateScoring {
|
|||
mRawLength = 0.0f;
|
||||
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
|
||||
mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
|
||||
mExactMatch = true;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
|
||||
|
@ -56,17 +57,32 @@ class DicNodeStateScoring {
|
|||
mRawLength = scoring->mRawLength;
|
||||
mDoubleLetterLevel = scoring->mDoubleLetterLevel;
|
||||
mDigraphIndex = scoring->mDigraphIndex;
|
||||
mExactMatch = scoring->mExactMatch;
|
||||
}
|
||||
|
||||
void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
|
||||
const int inputSize, const int totalInputIndex, const bool isEditCorrection,
|
||||
const bool isProximityCorrection) {
|
||||
const int inputSize, const int totalInputIndex, const ErrorType errorType) {
|
||||
addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
|
||||
if (isEditCorrection) {
|
||||
++mEditCorrectionCount;
|
||||
}
|
||||
if (isProximityCorrection) {
|
||||
++mProximityCorrectionCount;
|
||||
switch (errorType) {
|
||||
case ET_EDIT_CORRECTION:
|
||||
++mEditCorrectionCount;
|
||||
mExactMatch = false;
|
||||
break;
|
||||
case ET_PROXIMITY_CORRECTION:
|
||||
++mProximityCorrectionCount;
|
||||
mExactMatch = false;
|
||||
break;
|
||||
case ET_COMPLETION:
|
||||
mExactMatch = false;
|
||||
break;
|
||||
case ET_NEW_WORD:
|
||||
mExactMatch = false;
|
||||
break;
|
||||
case ET_INTENTIONAL_OMISSION:
|
||||
mExactMatch = false;
|
||||
break;
|
||||
case ET_NOT_AN_ERROR:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,6 +159,10 @@ class DicNodeStateScoring {
|
|||
}
|
||||
}
|
||||
|
||||
bool isExactMatch() const {
|
||||
return mExactMatch;
|
||||
}
|
||||
|
||||
private:
|
||||
// Caution!!!
|
||||
// Use a default copy constructor and an assign operator because shallow copies are ok
|
||||
|
@ -157,6 +177,7 @@ class DicNodeStateScoring {
|
|||
float mSpatialDistance;
|
||||
float mLanguageDistance;
|
||||
float mRawLength;
|
||||
bool mExactMatch;
|
||||
|
||||
AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
|
||||
bool doNormalization, int inputSize, int totalInputIndex) {
|
||||
|
|
|
@ -80,9 +80,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
traverseSession, parentDicNode, dicNode, &inputStateG);
|
||||
const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
|
||||
traverseSession, parentDicNode, dicNode, bigramCacheMap);
|
||||
const bool edit = Weighting::isEditCorrection(correctionType);
|
||||
const bool proximity = Weighting::isProximityCorrection(weighting, correctionType,
|
||||
traverseSession, dicNode);
|
||||
const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession,
|
||||
parentDicNode, dicNode);
|
||||
profile(correctionType, dicNode);
|
||||
if (inputStateG.mNeedsToUpdateInputStateG) {
|
||||
dicNode->updateInputIndexG(&inputStateG);
|
||||
|
@ -91,7 +90,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
(correctionType == CT_TRANSPOSITION));
|
||||
}
|
||||
dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
|
||||
inputSize, edit, proximity);
|
||||
inputSize, errorType);
|
||||
}
|
||||
|
||||
/* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
|
||||
|
@ -158,62 +157,6 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
}
|
||||
}
|
||||
|
||||
/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) {
|
||||
switch(correctionType) {
|
||||
case CT_OMISSION:
|
||||
return true;
|
||||
case CT_ADDITIONAL_PROXIMITY:
|
||||
return true;
|
||||
case CT_SUBSTITUTION:
|
||||
return true;
|
||||
case CT_NEW_WORD_SPACE_OMITTION:
|
||||
return false;
|
||||
case CT_MATCH:
|
||||
return false;
|
||||
case CT_COMPLETION:
|
||||
return false;
|
||||
case CT_TERMINAL:
|
||||
return false;
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
return false;
|
||||
case CT_INSERTION:
|
||||
return true;
|
||||
case CT_TRANSPOSITION:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting,
|
||||
const CorrectionType correctionType,
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) {
|
||||
switch(correctionType) {
|
||||
case CT_OMISSION:
|
||||
return false;
|
||||
case CT_ADDITIONAL_PROXIMITY:
|
||||
return true;
|
||||
case CT_SUBSTITUTION:
|
||||
return false;
|
||||
case CT_NEW_WORD_SPACE_OMITTION:
|
||||
return false;
|
||||
case CT_MATCH:
|
||||
return weighting->isProximityDicNode(traverseSession, dicNode);
|
||||
case CT_COMPLETION:
|
||||
return false;
|
||||
case CT_TERMINAL:
|
||||
return false;
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
return false;
|
||||
case CT_INSERTION:
|
||||
return false;
|
||||
case CT_TRANSPOSITION:
|
||||
return false;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
|
||||
switch(correctionType) {
|
||||
case CT_OMISSION:
|
||||
|
|
|
@ -80,6 +80,10 @@ class Weighting {
|
|||
virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const = 0;
|
||||
|
||||
virtual ErrorType getErrorType(const CorrectionType correctionType,
|
||||
const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
|
||||
|
||||
Weighting() {}
|
||||
virtual ~Weighting() {}
|
||||
|
||||
|
@ -95,12 +99,6 @@ class Weighting {
|
|||
const DicNode *const parentDicNode, const DicNode *const dicNode,
|
||||
hash_map_compat<int, int16_t> *const bigramCacheMap);
|
||||
// TODO: Move to TypingWeighting and GestureWeighting?
|
||||
static bool isEditCorrection(const CorrectionType correctionType);
|
||||
// TODO: Move to TypingWeighting and GestureWeighting?
|
||||
static bool isProximityCorrection(const Weighting *const weighting,
|
||||
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode);
|
||||
// TODO: Move to TypingWeighting and GestureWeighting?
|
||||
static int getForwardInputCount(const CorrectionType correctionType);
|
||||
};
|
||||
} // namespace latinime
|
||||
|
|
|
@ -422,20 +422,15 @@ void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession,
|
|||
*/
|
||||
void Suggest::processDicNodeAsOmission(
|
||||
DicTraverseSession *traverseSession, DicNode *dicNode) const {
|
||||
// If the omission is surely intentional that it should incur zero cost.
|
||||
const bool isZeroCostOmission = dicNode->isZeroCostOmission();
|
||||
DicNodeVector childDicNodes;
|
||||
|
||||
DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes);
|
||||
|
||||
const int size = childDicNodes.getSizeAndLock();
|
||||
for (int i = 0; i < size; i++) {
|
||||
DicNode *const childDicNode = childDicNodes[i];
|
||||
if (!isZeroCostOmission) {
|
||||
// Treat this word as omission
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
|
||||
dicNode, childDicNode, 0 /* bigramCacheMap */);
|
||||
}
|
||||
// Treat this word as omission
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
|
||||
dicNode, childDicNode, 0 /* bigramCacheMap */);
|
||||
weightChildNode(traverseSession, childDicNode);
|
||||
|
||||
if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) {
|
||||
|
|
|
@ -21,4 +21,39 @@
|
|||
|
||||
namespace latinime {
|
||||
const TypingWeighting TypingWeighting::sInstance;
|
||||
|
||||
ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType,
|
||||
const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
||||
switch (correctionType) {
|
||||
case CT_MATCH:
|
||||
if (isProximityDicNode(traverseSession, dicNode)) {
|
||||
return ET_PROXIMITY_CORRECTION;
|
||||
} else {
|
||||
return ET_NOT_AN_ERROR;
|
||||
}
|
||||
case CT_ADDITIONAL_PROXIMITY:
|
||||
return ET_PROXIMITY_CORRECTION;
|
||||
case CT_OMISSION:
|
||||
if (parentDicNode->canBeIntentionalOmission()) {
|
||||
return ET_INTENTIONAL_OMISSION;
|
||||
} else {
|
||||
return ET_EDIT_CORRECTION;
|
||||
}
|
||||
break;
|
||||
case CT_SUBSTITUTION:
|
||||
case CT_INSERTION:
|
||||
case CT_TRANSPOSITION:
|
||||
return ET_EDIT_CORRECTION;
|
||||
case CT_NEW_WORD_SPACE_OMITTION:
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
return ET_NEW_WORD;
|
||||
case CT_TERMINAL:
|
||||
return ET_NOT_AN_ERROR;
|
||||
case CT_COMPLETION:
|
||||
return ET_COMPLETION;
|
||||
default:
|
||||
return ET_NOT_AN_ERROR;
|
||||
}
|
||||
}
|
||||
} // namespace latinime
|
||||
|
|
|
@ -50,13 +50,14 @@ class TypingWeighting : public Weighting {
|
|||
}
|
||||
|
||||
float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
||||
bool sameCodePoint = false;
|
||||
bool isFirstLetterOmission = false;
|
||||
float cost = 0.0f;
|
||||
sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
|
||||
const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
|
||||
const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
|
||||
// If the traversal omitted the first letter then the dicNode should now be on the second.
|
||||
isFirstLetterOmission = dicNode->getDepth() == 2;
|
||||
if (isFirstLetterOmission) {
|
||||
const bool isFirstLetterOmission = dicNode->getDepth() == 2;
|
||||
float cost = 0.0f;
|
||||
if (isZeroCostOmission) {
|
||||
cost = 0.0f;
|
||||
} else if (isFirstLetterOmission) {
|
||||
cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
|
||||
} else {
|
||||
cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
|
||||
|
@ -156,15 +157,8 @@ class TypingWeighting : public Weighting {
|
|||
|
||||
float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
|
||||
const bool hasEditCount = dicNode->getEditCorrectionCount() > 0;
|
||||
const bool isSameLength = dicNode->getDepth() == traverseSession->getInputSize();
|
||||
const bool hasMultipleWords = dicNode->hasMultipleWords();
|
||||
const bool hasProximityErrors = dicNode->getProximityCorrectionCount() > 0;
|
||||
// Gesture input is always assumed to have proximity errors
|
||||
// because the input word shouldn't be treated as perfect
|
||||
const bool isExactMatch = !hasEditCount && !hasMultipleWords
|
||||
&& !hasProximityErrors && isSameLength;
|
||||
const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability;
|
||||
const float languageImprobability = (dicNode->isExactMatch()) ?
|
||||
0.0f : dicNodeLanguageImprobability;
|
||||
return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
||||
}
|
||||
|
||||
|
@ -189,6 +183,10 @@ class TypingWeighting : public Weighting {
|
|||
return cost * traverseSession->getMultiWordCostMultiplier();
|
||||
}
|
||||
|
||||
ErrorType getErrorType(const CorrectionType correctionType,
|
||||
const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const;
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
|
||||
static const TypingWeighting sInstance;
|
||||
|
|
Loading…
Reference in a new issue