Merge "Differentiate exact matches' minor errors."

This commit is contained in:
Keisuke Kuroyanagi 2013-12-19 04:14:11 +00:00 committed by Android (Google) Code Review
commit b68dd6cd0b
8 changed files with 55 additions and 25 deletions

View file

@@ -23,6 +23,7 @@
#include "suggest/core/dicnode/internal/dic_node_state.h" #include "suggest/core/dicnode/internal/dic_node_state.h"
#include "suggest/core/dicnode/internal/dic_node_properties.h" #include "suggest/core/dicnode/internal/dic_node_properties.h"
#include "suggest/core/dictionary/digraph_utils.h" #include "suggest/core/dictionary/digraph_utils.h"
#include "suggest/core/dictionary/error_type_utils.h"
#include "utils/char_utils.h" #include "utils/char_utils.h"
#if DEBUG_DICT #if DEBUG_DICT
@@ -493,8 +494,8 @@ class DicNode {
mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
} }
bool isExactMatch() const { ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
return mDicNodeState.mDicNodeStateScoring.isExactMatch(); return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes();
} }
bool isBlacklistedOrNotAWord() const { bool isBlacklistedOrNotAWord() const {
@@ -535,8 +536,8 @@ class DicNode {
return false; return false;
} }
// Promote exact matches to prevent them from being pruned. // Promote exact matches to prevent them from being pruned.
const bool leftExactMatch = isExactMatch(); const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes());
const bool rightExactMatch = right->isExactMatch(); const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes());
if (leftExactMatch != rightExactMatch) { if (leftExactMatch != rightExactMatch) {
return leftExactMatch; return leftExactMatch;
} }

View file

@@ -32,7 +32,7 @@ class DicNodeStateScoring {
mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
mEditCorrectionCount(0), mProximityCorrectionCount(0), mEditCorrectionCount(0), mProximityCorrectionCount(0),
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
mRawLength(0.0f), mContainingErrorTypes(ErrorTypeUtils::NOT_AN_ERROR), mRawLength(0.0f), mContainedErrorTypes(ErrorTypeUtils::NOT_AN_ERROR),
mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) { mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) {
} }
@@ -48,7 +48,7 @@ class DicNodeStateScoring {
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING; mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING;
mContainingErrorTypes = ErrorTypeUtils::NOT_AN_ERROR; mContainedErrorTypes = ErrorTypeUtils::NOT_AN_ERROR;
} }
AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) { AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
@@ -60,7 +60,7 @@ class DicNodeStateScoring {
mRawLength = scoring->mRawLength; mRawLength = scoring->mRawLength;
mDoubleLetterLevel = scoring->mDoubleLetterLevel; mDoubleLetterLevel = scoring->mDoubleLetterLevel;
mDigraphIndex = scoring->mDigraphIndex; mDigraphIndex = scoring->mDigraphIndex;
mContainingErrorTypes = scoring->mContainingErrorTypes; mContainedErrorTypes = scoring->mContainedErrorTypes;
mNormalizedCompoundDistanceAfterFirstWord = mNormalizedCompoundDistanceAfterFirstWord =
scoring->mNormalizedCompoundDistanceAfterFirstWord; scoring->mNormalizedCompoundDistanceAfterFirstWord;
} }
@@ -69,7 +69,7 @@ class DicNodeStateScoring {
const int inputSize, const int totalInputIndex, const int inputSize, const int totalInputIndex,
const ErrorTypeUtils::ErrorType errorType) { const ErrorTypeUtils::ErrorType errorType) {
addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex); addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
mContainingErrorTypes = mContainingErrorTypes | errorType; mContainedErrorTypes = mContainedErrorTypes | errorType;
if (ErrorTypeUtils::isEditCorrectionError(errorType)) { if (ErrorTypeUtils::isEditCorrectionError(errorType)) {
++mEditCorrectionCount; ++mEditCorrectionCount;
} }
@@ -169,8 +169,8 @@ class DicNodeStateScoring {
} }
} }
bool isExactMatch() const { ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
return ErrorTypeUtils::isExactMatch(mContainingErrorTypes); return mContainedErrorTypes;
} }
private: private:
@@ -188,7 +188,7 @@ class DicNodeStateScoring {
float mLanguageDistance; float mLanguageDistance;
float mRawLength; float mRawLength;
// All accumulated error types so far // All accumulated error types so far
ErrorTypeUtils::ErrorType mContainingErrorTypes; ErrorTypeUtils::ErrorType mContainedErrorTypes;
float mNormalizedCompoundDistanceAfterFirstWord; float mNormalizedCompoundDistanceAfterFirstWord;
AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,

View file

@@ -47,9 +47,8 @@ class ErrorTypeUtils {
// A new word error should be an edit correction error or a proximity correction error. // A new word error should be an edit correction error or a proximity correction error.
static const ErrorType NEW_WORD; static const ErrorType NEW_WORD;
// TODO: Differentiate errors. static bool isExactMatch(const ErrorType containedErrorTypes) {
static bool isExactMatch(const ErrorType containingErrors) { return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0;
return (containingErrors & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0;
} }
static bool isEditCorrectionError(const ErrorType errorType) { static bool isEditCorrectionError(const ErrorType errorType) {

View file

@@ -18,8 +18,9 @@
#include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_utils.h" #include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/dictionary/dictionary.h"
#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h" #include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h"
#include "suggest/core/dictionary/dictionary.h"
#include "suggest/core/dictionary/error_type_utils.h"
#include "suggest/core/policy/scoring.h" #include "suggest/core/policy/scoring.h"
#include "suggest/core/session/dic_traverse_session.h" #include "suggest/core/session/dic_traverse_session.h"
@@ -98,7 +99,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
const bool isPossiblyOffensiveWord = const bool isPossiblyOffensiveWord =
traverseSession->getDictionaryStructurePolicy()->getProbability( traverseSession->getDictionaryStructurePolicy()->getProbability(
terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0; terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0;
const bool isExactMatch = terminalDicNode->isExactMatch(); const bool isExactMatch =
ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes());
const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase();
// Heuristic: We exclude freq=0 first-char-uppercase words from exact match. // Heuristic: We exclude freq=0 first-char-uppercase words from exact match.
// (e.g. "AMD" and "and") // (e.g. "AMD" and "and")
@@ -115,9 +117,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
// TODO: Better integration with java side autocorrection logic. // TODO: Better integration with java side autocorrection logic.
const int finalScore = scoringPolicy->calculateFinalScore( const int finalScore = scoringPolicy->calculateFinalScore(
compoundDistance, traverseSession->getInputSize(), compoundDistance, traverseSession->getInputSize(),
terminalDicNode->isExactMatch() terminalDicNode->getContainedErrorTypes(),
|| (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) (forceCommitMultiWords && terminalDicNode->hasMultipleWords())
|| (isValidWord && scoringPolicy->doesAutoCorrectValidWord())); || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()));
if (maxScore < finalScore && isValidWord) { if (maxScore < finalScore && isValidWord) {
maxScore = finalScore; maxScore = finalScore;
} }
@@ -149,7 +151,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode); const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode);
const int shortcutBaseScore = scoringPolicy->doesAutoCorrectValidWord() ? const int shortcutBaseScore = scoringPolicy->doesAutoCorrectValidWord() ?
scoringPolicy->calculateFinalScore(compoundDistance, scoringPolicy->calculateFinalScore(compoundDistance,
traverseSession->getInputSize(), true /* forceCommit */) : finalScore; traverseSession->getInputSize(),
terminalDicNode->getContainedErrorTypes(),
true /* forceCommit */) : finalScore;
const int updatedOutputWordIndex = outputShortcuts(&shortcutIt, const int updatedOutputWordIndex = outputShortcuts(&shortcutIt,
outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes, outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes,
sameAsTyped); sameAsTyped);

View file

@@ -28,7 +28,7 @@ class DicTraverseSession;
class Scoring { class Scoring {
public: public:
virtual int calculateFinalScore(const float compoundDistance, const int inputSize, virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
const bool forceCommit) const = 0; const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit) const = 0;
virtual bool getMostProbableString(const DicTraverseSession *const traverseSession, virtual bool getMostProbableString(const DicTraverseSession *const traverseSession,
const int terminalSize, const float languageWeight, int *const outputCodePoints, const int terminalSize, const float languageWeight, int *const outputCodePoints,
int *const type, int *const freq) const = 0; int *const type, int *const freq) const = 0;

View file

@@ -22,6 +22,12 @@ const float ScoringParams::MAX_SPATIAL_DISTANCE = 1.0f;
const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY = 40; const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY = 40;
const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120; const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120;
const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f; const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f;
const float ScoringParams::EXACT_MATCH_PROMOTION = 1.1f;
const float ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH = 0.01f;
const float ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH = 0.02f;
const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f;
// TODO: Unlimit max cache dic node size // TODO: Unlimit max cache dic node size
const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170; const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170;
const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310; const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310;

View file

@@ -32,6 +32,11 @@ class ScoringParams {
static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT; static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT;
static const int THRESHOLD_SHORT_WORD_LENGTH; static const int THRESHOLD_SHORT_WORD_LENGTH;
static const float EXACT_MATCH_PROMOTION;
static const float CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
static const float ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
static const float DIGRAPH_PENALTY_FOR_EXACT_MATCH;
// Numerically optimized parameters (currently for tap typing only). // Numerically optimized parameters (currently for tap typing only).
// TODO: add ability to modify these constants programmatically. // TODO: add ability to modify these constants programmatically.
// TODO: explore optimization of gesture parameters. // TODO: explore optimization of gesture parameters.

View file

@@ -18,6 +18,7 @@
#define LATINIME_TYPING_SCORING_H #define LATINIME_TYPING_SCORING_H
#include "defines.h" #include "defines.h"
#include "suggest/core/dictionary/error_type_utils.h"
#include "suggest/core/policy/scoring.h" #include "suggest/core/policy/scoring.h"
#include "suggest/core/session/dic_traverse_session.h" #include "suggest/core/session/dic_traverse_session.h"
#include "suggest/policyimpl/typing/scoring_params.h" #include "suggest/policyimpl/typing/scoring_params.h"
@@ -53,12 +54,26 @@ class TypingScoring : public Scoring {
} }
AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance,
const int inputSize, const bool forceCommit) const { const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes,
const bool forceCommit) const {
const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE
+ static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; + static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
const float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance;
- compoundDistance / maxDistance if (forceCommit) {
+ (forceCommit ? ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD : 0.0f); score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD;
}
if (ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
score += ScoringParams::EXACT_MATCH_PROMOTION;
if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) {
score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
}
if ((ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR & containedErrorTypes) != 0) {
score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
}
if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) {
score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH;
}
}
return static_cast<int>(score * SUGGEST_INTERFACE_OUTPUT_SCALE); return static_cast<int>(score * SUGGEST_INTERFACE_OUTPUT_SCALE);
} }