Merge "Differentiate exact matches' minor errors."
This commit is contained in:
commit
b68dd6cd0b
8 changed files with 55 additions and 25 deletions
|
@ -23,6 +23,7 @@
|
||||||
#include "suggest/core/dicnode/internal/dic_node_state.h"
|
#include "suggest/core/dicnode/internal/dic_node_state.h"
|
||||||
#include "suggest/core/dicnode/internal/dic_node_properties.h"
|
#include "suggest/core/dicnode/internal/dic_node_properties.h"
|
||||||
#include "suggest/core/dictionary/digraph_utils.h"
|
#include "suggest/core/dictionary/digraph_utils.h"
|
||||||
|
#include "suggest/core/dictionary/error_type_utils.h"
|
||||||
#include "utils/char_utils.h"
|
#include "utils/char_utils.h"
|
||||||
|
|
||||||
#if DEBUG_DICT
|
#if DEBUG_DICT
|
||||||
|
@ -493,8 +494,8 @@ class DicNode {
|
||||||
mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
|
mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isExactMatch() const {
|
ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
|
||||||
return mDicNodeState.mDicNodeStateScoring.isExactMatch();
|
return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isBlacklistedOrNotAWord() const {
|
bool isBlacklistedOrNotAWord() const {
|
||||||
|
@ -535,8 +536,8 @@ class DicNode {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Promote exact matches to prevent them from being pruned.
|
// Promote exact matches to prevent them from being pruned.
|
||||||
const bool leftExactMatch = isExactMatch();
|
const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes());
|
||||||
const bool rightExactMatch = right->isExactMatch();
|
const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes());
|
||||||
if (leftExactMatch != rightExactMatch) {
|
if (leftExactMatch != rightExactMatch) {
|
||||||
return leftExactMatch;
|
return leftExactMatch;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ class DicNodeStateScoring {
|
||||||
mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
|
mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
|
||||||
mEditCorrectionCount(0), mProximityCorrectionCount(0),
|
mEditCorrectionCount(0), mProximityCorrectionCount(0),
|
||||||
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
|
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
|
||||||
mRawLength(0.0f), mContainingErrorTypes(ErrorTypeUtils::NOT_AN_ERROR),
|
mRawLength(0.0f), mContainedErrorTypes(ErrorTypeUtils::NOT_AN_ERROR),
|
||||||
mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) {
|
mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ class DicNodeStateScoring {
|
||||||
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
|
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
|
||||||
mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
|
mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
|
||||||
mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING;
|
mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING;
|
||||||
mContainingErrorTypes = ErrorTypeUtils::NOT_AN_ERROR;
|
mContainedErrorTypes = ErrorTypeUtils::NOT_AN_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
|
AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
|
||||||
|
@ -60,7 +60,7 @@ class DicNodeStateScoring {
|
||||||
mRawLength = scoring->mRawLength;
|
mRawLength = scoring->mRawLength;
|
||||||
mDoubleLetterLevel = scoring->mDoubleLetterLevel;
|
mDoubleLetterLevel = scoring->mDoubleLetterLevel;
|
||||||
mDigraphIndex = scoring->mDigraphIndex;
|
mDigraphIndex = scoring->mDigraphIndex;
|
||||||
mContainingErrorTypes = scoring->mContainingErrorTypes;
|
mContainedErrorTypes = scoring->mContainedErrorTypes;
|
||||||
mNormalizedCompoundDistanceAfterFirstWord =
|
mNormalizedCompoundDistanceAfterFirstWord =
|
||||||
scoring->mNormalizedCompoundDistanceAfterFirstWord;
|
scoring->mNormalizedCompoundDistanceAfterFirstWord;
|
||||||
}
|
}
|
||||||
|
@ -69,7 +69,7 @@ class DicNodeStateScoring {
|
||||||
const int inputSize, const int totalInputIndex,
|
const int inputSize, const int totalInputIndex,
|
||||||
const ErrorTypeUtils::ErrorType errorType) {
|
const ErrorTypeUtils::ErrorType errorType) {
|
||||||
addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
|
addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
|
||||||
mContainingErrorTypes = mContainingErrorTypes | errorType;
|
mContainedErrorTypes = mContainedErrorTypes | errorType;
|
||||||
if (ErrorTypeUtils::isEditCorrectionError(errorType)) {
|
if (ErrorTypeUtils::isEditCorrectionError(errorType)) {
|
||||||
++mEditCorrectionCount;
|
++mEditCorrectionCount;
|
||||||
}
|
}
|
||||||
|
@ -169,8 +169,8 @@ class DicNodeStateScoring {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isExactMatch() const {
|
ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
|
||||||
return ErrorTypeUtils::isExactMatch(mContainingErrorTypes);
|
return mContainedErrorTypes;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -188,7 +188,7 @@ class DicNodeStateScoring {
|
||||||
float mLanguageDistance;
|
float mLanguageDistance;
|
||||||
float mRawLength;
|
float mRawLength;
|
||||||
// All accumulated error types so far
|
// All accumulated error types so far
|
||||||
ErrorTypeUtils::ErrorType mContainingErrorTypes;
|
ErrorTypeUtils::ErrorType mContainedErrorTypes;
|
||||||
float mNormalizedCompoundDistanceAfterFirstWord;
|
float mNormalizedCompoundDistanceAfterFirstWord;
|
||||||
|
|
||||||
AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
|
AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
|
||||||
|
|
|
@ -47,9 +47,8 @@ class ErrorTypeUtils {
|
||||||
// A new word error should be an edit correction error or a proximity correction error.
|
// A new word error should be an edit correction error or a proximity correction error.
|
||||||
static const ErrorType NEW_WORD;
|
static const ErrorType NEW_WORD;
|
||||||
|
|
||||||
// TODO: Differentiate errors.
|
static bool isExactMatch(const ErrorType containedErrorTypes) {
|
||||||
static bool isExactMatch(const ErrorType containingErrors) {
|
return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0;
|
||||||
return (containingErrors & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEditCorrectionError(const ErrorType errorType) {
|
static bool isEditCorrectionError(const ErrorType errorType) {
|
||||||
|
|
|
@ -18,8 +18,9 @@
|
||||||
|
|
||||||
#include "suggest/core/dicnode/dic_node.h"
|
#include "suggest/core/dicnode/dic_node.h"
|
||||||
#include "suggest/core/dicnode/dic_node_utils.h"
|
#include "suggest/core/dicnode/dic_node_utils.h"
|
||||||
#include "suggest/core/dictionary/dictionary.h"
|
|
||||||
#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h"
|
#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h"
|
||||||
|
#include "suggest/core/dictionary/dictionary.h"
|
||||||
|
#include "suggest/core/dictionary/error_type_utils.h"
|
||||||
#include "suggest/core/policy/scoring.h"
|
#include "suggest/core/policy/scoring.h"
|
||||||
#include "suggest/core/session/dic_traverse_session.h"
|
#include "suggest/core/session/dic_traverse_session.h"
|
||||||
|
|
||||||
|
@ -98,7 +99,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
|
||||||
const bool isPossiblyOffensiveWord =
|
const bool isPossiblyOffensiveWord =
|
||||||
traverseSession->getDictionaryStructurePolicy()->getProbability(
|
traverseSession->getDictionaryStructurePolicy()->getProbability(
|
||||||
terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0;
|
terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0;
|
||||||
const bool isExactMatch = terminalDicNode->isExactMatch();
|
const bool isExactMatch =
|
||||||
|
ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes());
|
||||||
const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase();
|
const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase();
|
||||||
// Heuristic: We exclude freq=0 first-char-uppercase words from exact match.
|
// Heuristic: We exclude freq=0 first-char-uppercase words from exact match.
|
||||||
// (e.g. "AMD" and "and")
|
// (e.g. "AMD" and "and")
|
||||||
|
@ -115,8 +117,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
|
||||||
// TODO: Better integration with java side autocorrection logic.
|
// TODO: Better integration with java side autocorrection logic.
|
||||||
const int finalScore = scoringPolicy->calculateFinalScore(
|
const int finalScore = scoringPolicy->calculateFinalScore(
|
||||||
compoundDistance, traverseSession->getInputSize(),
|
compoundDistance, traverseSession->getInputSize(),
|
||||||
terminalDicNode->isExactMatch()
|
terminalDicNode->getContainedErrorTypes(),
|
||||||
|| (forceCommitMultiWords && terminalDicNode->hasMultipleWords())
|
(forceCommitMultiWords && terminalDicNode->hasMultipleWords())
|
||||||
|| (isValidWord && scoringPolicy->doesAutoCorrectValidWord()));
|
|| (isValidWord && scoringPolicy->doesAutoCorrectValidWord()));
|
||||||
if (maxScore < finalScore && isValidWord) {
|
if (maxScore < finalScore && isValidWord) {
|
||||||
maxScore = finalScore;
|
maxScore = finalScore;
|
||||||
|
@ -149,7 +151,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
|
||||||
const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode);
|
const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode);
|
||||||
const int shortcutBaseScore = scoringPolicy->doesAutoCorrectValidWord() ?
|
const int shortcutBaseScore = scoringPolicy->doesAutoCorrectValidWord() ?
|
||||||
scoringPolicy->calculateFinalScore(compoundDistance,
|
scoringPolicy->calculateFinalScore(compoundDistance,
|
||||||
traverseSession->getInputSize(), true /* forceCommit */) : finalScore;
|
traverseSession->getInputSize(),
|
||||||
|
terminalDicNode->getContainedErrorTypes(),
|
||||||
|
true /* forceCommit */) : finalScore;
|
||||||
const int updatedOutputWordIndex = outputShortcuts(&shortcutIt,
|
const int updatedOutputWordIndex = outputShortcuts(&shortcutIt,
|
||||||
outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes,
|
outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes,
|
||||||
sameAsTyped);
|
sameAsTyped);
|
||||||
|
|
|
@ -28,7 +28,7 @@ class DicTraverseSession;
|
||||||
class Scoring {
|
class Scoring {
|
||||||
public:
|
public:
|
||||||
virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
|
virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
|
||||||
const bool forceCommit) const = 0;
|
const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit) const = 0;
|
||||||
virtual bool getMostProbableString(const DicTraverseSession *const traverseSession,
|
virtual bool getMostProbableString(const DicTraverseSession *const traverseSession,
|
||||||
const int terminalSize, const float languageWeight, int *const outputCodePoints,
|
const int terminalSize, const float languageWeight, int *const outputCodePoints,
|
||||||
int *const type, int *const freq) const = 0;
|
int *const type, int *const freq) const = 0;
|
||||||
|
|
|
@ -22,6 +22,12 @@ const float ScoringParams::MAX_SPATIAL_DISTANCE = 1.0f;
|
||||||
const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY = 40;
|
const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY = 40;
|
||||||
const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120;
|
const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120;
|
||||||
const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f;
|
const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f;
|
||||||
|
|
||||||
|
const float ScoringParams::EXACT_MATCH_PROMOTION = 1.1f;
|
||||||
|
const float ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH = 0.01f;
|
||||||
|
const float ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH = 0.02f;
|
||||||
|
const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f;
|
||||||
|
|
||||||
// TODO: Unlimit max cache dic node size
|
// TODO: Unlimit max cache dic node size
|
||||||
const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170;
|
const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170;
|
||||||
const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310;
|
const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310;
|
||||||
|
|
|
@ -32,6 +32,11 @@ class ScoringParams {
|
||||||
static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT;
|
static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT;
|
||||||
static const int THRESHOLD_SHORT_WORD_LENGTH;
|
static const int THRESHOLD_SHORT_WORD_LENGTH;
|
||||||
|
|
||||||
|
static const float EXACT_MATCH_PROMOTION;
|
||||||
|
static const float CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
|
||||||
|
static const float ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
|
||||||
|
static const float DIGRAPH_PENALTY_FOR_EXACT_MATCH;
|
||||||
|
|
||||||
// Numerically optimized parameters (currently for tap typing only).
|
// Numerically optimized parameters (currently for tap typing only).
|
||||||
// TODO: add ability to modify these constants programmatically.
|
// TODO: add ability to modify these constants programmatically.
|
||||||
// TODO: explore optimization of gesture parameters.
|
// TODO: explore optimization of gesture parameters.
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define LATINIME_TYPING_SCORING_H
|
#define LATINIME_TYPING_SCORING_H
|
||||||
|
|
||||||
#include "defines.h"
|
#include "defines.h"
|
||||||
|
#include "suggest/core/dictionary/error_type_utils.h"
|
||||||
#include "suggest/core/policy/scoring.h"
|
#include "suggest/core/policy/scoring.h"
|
||||||
#include "suggest/core/session/dic_traverse_session.h"
|
#include "suggest/core/session/dic_traverse_session.h"
|
||||||
#include "suggest/policyimpl/typing/scoring_params.h"
|
#include "suggest/policyimpl/typing/scoring_params.h"
|
||||||
|
@ -53,12 +54,26 @@ class TypingScoring : public Scoring {
|
||||||
}
|
}
|
||||||
|
|
||||||
AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance,
|
AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance,
|
||||||
const int inputSize, const bool forceCommit) const {
|
const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes,
|
||||||
|
const bool forceCommit) const {
|
||||||
const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE
|
const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE
|
||||||
+ static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
|
+ static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
|
||||||
const float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE
|
float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance;
|
||||||
- compoundDistance / maxDistance
|
if (forceCommit) {
|
||||||
+ (forceCommit ? ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD : 0.0f);
|
score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD;
|
||||||
|
}
|
||||||
|
if (ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
|
||||||
|
score += ScoringParams::EXACT_MATCH_PROMOTION;
|
||||||
|
if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) {
|
||||||
|
score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
|
||||||
|
}
|
||||||
|
if ((ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR & containedErrorTypes) != 0) {
|
||||||
|
score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
|
||||||
|
}
|
||||||
|
if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) {
|
||||||
|
score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH;
|
||||||
|
}
|
||||||
|
}
|
||||||
return static_cast<int>(score * SUGGEST_INTERFACE_OUTPUT_SCALE);
|
return static_cast<int>(score * SUGGEST_INTERFACE_OUTPUT_SCALE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue