diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 069852d6e..558667eb0 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -23,6 +23,7 @@ #include "suggest/core/dicnode/internal/dic_node_state.h" #include "suggest/core/dicnode/internal/dic_node_properties.h" #include "suggest/core/dictionary/digraph_utils.h" +#include "suggest/core/dictionary/error_type_utils.h" #include "utils/char_utils.h" #if DEBUG_DICT @@ -493,8 +494,8 @@ class DicNode { mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); } - bool isExactMatch() const { - return mDicNodeState.mDicNodeStateScoring.isExactMatch(); + ErrorTypeUtils::ErrorType getContainedErrorTypes() const { + return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes(); } bool isBlacklistedOrNotAWord() const { @@ -535,8 +536,8 @@ class DicNode { return false; } // Promote exact matches to prevent them from being pruned. - const bool leftExactMatch = isExactMatch(); - const bool rightExactMatch = right->isExactMatch(); + const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes()); + const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes()); if (leftExactMatch != rightExactMatch) { return leftExactMatch; } diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h index 74f9eee92..11c201e52 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h @@ -32,7 +32,7 @@ class DicNodeStateScoring { mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), mEditCorrectionCount(0), mProximityCorrectionCount(0), mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), - mRawLength(0.0f), mContainingErrorTypes(ErrorTypeUtils::NOT_AN_ERROR), + mRawLength(0.0f), mContainedErrorTypes(ErrorTypeUtils::NOT_AN_ERROR), mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) { } @@ -48,7 +48,7 @@ class DicNodeStateScoring { mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING; - mContainingErrorTypes = ErrorTypeUtils::NOT_AN_ERROR; + mContainedErrorTypes = ErrorTypeUtils::NOT_AN_ERROR; } AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) { @@ -60,7 +60,7 @@ class DicNodeStateScoring { mRawLength = scoring->mRawLength; mDoubleLetterLevel = scoring->mDoubleLetterLevel; mDigraphIndex = scoring->mDigraphIndex; - mContainingErrorTypes = scoring->mContainingErrorTypes; + mContainedErrorTypes = scoring->mContainedErrorTypes; mNormalizedCompoundDistanceAfterFirstWord = scoring->mNormalizedCompoundDistanceAfterFirstWord; } @@ -69,7 +69,7 @@ class DicNodeStateScoring { const int inputSize, const int totalInputIndex, const ErrorTypeUtils::ErrorType errorType) { addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex); - mContainingErrorTypes = mContainingErrorTypes | errorType; + mContainedErrorTypes = mContainedErrorTypes | errorType; if (ErrorTypeUtils::isEditCorrectionError(errorType)) { ++mEditCorrectionCount; } @@ -169,8 +169,8 @@ class DicNodeStateScoring { } } - bool isExactMatch() const { - return ErrorTypeUtils::isExactMatch(mContainingErrorTypes); + ErrorTypeUtils::ErrorType getContainedErrorTypes() const { + return mContainedErrorTypes; } private: @@ -188,7 +188,7 @@ class DicNodeStateScoring { float mLanguageDistance; float mRawLength; // All accumulated error types so far - ErrorTypeUtils::ErrorType mContainingErrorTypes; + ErrorTypeUtils::ErrorType mContainedErrorTypes; float mNormalizedCompoundDistanceAfterFirstWord; AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.h b/native/jni/src/suggest/core/dictionary/error_type_utils.h index ab4a65e48..1122291a6 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.h +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.h @@ -47,9 +47,8 @@ class ErrorTypeUtils { // A new word error should be an edit correction error or a proximity correction error. static const ErrorType NEW_WORD; - // TODO: Differentiate errors. - static bool isExactMatch(const ErrorType containingErrors) { - return (containingErrors & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0; + static bool isExactMatch(const ErrorType containedErrorTypes) { + return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0; } static bool isEditCorrectionError(const ErrorType errorType) { diff --git a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp index a3ca7e748..d219757da 100644 --- a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp @@ -18,8 +18,9 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_utils.h" -#include "suggest/core/dictionary/dictionary.h" #include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h" +#include "suggest/core/dictionary/dictionary.h" +#include "suggest/core/dictionary/error_type_utils.h" #include "suggest/core/policy/scoring.h" #include "suggest/core/session/dic_traverse_session.h" @@ -98,7 +99,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; const bool isPossiblyOffensiveWord = traverseSession->getDictionaryStructurePolicy()->getProbability( terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0; - const bool isExactMatch = terminalDicNode->isExactMatch(); + const bool isExactMatch = + ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes()); const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); // Heuristic: We exclude freq=0 first-char-uppercase words from exact match. // (e.g. "AMD" and "and") @@ -115,9 +117,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; // TODO: Better integration with java side autocorrection logic. const int finalScore = scoringPolicy->calculateFinalScore( compoundDistance, traverseSession->getInputSize(), - terminalDicNode->isExactMatch() - || (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) - || (isValidWord && scoringPolicy->doesAutoCorrectValidWord())); + terminalDicNode->getContainedErrorTypes(), + (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) + || (isValidWord && scoringPolicy->doesAutoCorrectValidWord())); if (maxScore < finalScore && isValidWord) { maxScore = finalScore; } @@ -149,7 +151,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode); const int shortcutBaseScore = scoringPolicy->doesAutoCorrectValidWord() ? scoringPolicy->calculateFinalScore(compoundDistance, - traverseSession->getInputSize(), true /* forceCommit */) : finalScore; + traverseSession->getInputSize(), + terminalDicNode->getContainedErrorTypes(), + true /* forceCommit */) : finalScore; const int updatedOutputWordIndex = outputShortcuts(&shortcutIt, outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index a16412996..5ae3d2146 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -28,7 +28,7 @@ class DicTraverseSession; class Scoring { public: virtual int calculateFinalScore(const float compoundDistance, const int inputSize, - const bool forceCommit) const = 0; + const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit) const = 0; virtual bool getMostProbableString(const DicTraverseSession *const traverseSession, const int terminalSize, const float languageWeight, int *const outputCodePoints, int *const type, int *const freq) const = 0; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index 104eb2a7a..7b332064c 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -22,6 +22,12 @@ const float ScoringParams::MAX_SPATIAL_DISTANCE = 1.0f; const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY = 40; const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120; const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f; + +const float ScoringParams::EXACT_MATCH_PROMOTION = 1.1f; +const float ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH = 0.01f; +const float ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH = 0.02f; +const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f; + // TODO: Unlimit max cache dic node size const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170; const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index 7d4b5c3c7..de7410d39 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -32,6 +32,11 @@ class ScoringParams { static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT; static const int THRESHOLD_SHORT_WORD_LENGTH; + static const float EXACT_MATCH_PROMOTION; + static const float CASE_ERROR_PENALTY_FOR_EXACT_MATCH; + static const float ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; + static const float DIGRAPH_PENALTY_FOR_EXACT_MATCH; + // Numerically optimized parameters (currently for tap typing only). // TODO: add ability to modify these constants programmatically. // TODO: explore optimization of gesture parameters. diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index b1337c641..186e3ba08 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -18,6 +18,7 @@ #define LATINIME_TYPING_SCORING_H #include "defines.h" +#include "suggest/core/dictionary/error_type_utils.h" #include "suggest/core/policy/scoring.h" #include "suggest/core/session/dic_traverse_session.h" #include "suggest/policyimpl/typing/scoring_params.h" @@ -53,12 +54,26 @@ class TypingScoring : public Scoring { } AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, - const int inputSize, const bool forceCommit) const { + const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, + const bool forceCommit) const { const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + static_cast(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; - const float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - - compoundDistance / maxDistance - + (forceCommit ? ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD : 0.0f); + float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance; + if (forceCommit) { + score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD; + } + if (ErrorTypeUtils::isExactMatch(containedErrorTypes)) { + score += ScoringParams::EXACT_MATCH_PROMOTION; + if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) { + score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; + } + if ((ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR & containedErrorTypes) != 0) { + score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; + } + if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { + score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH; + } + } return static_cast(score * SUGGEST_INTERFACE_OUTPUT_SCALE); }