diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp index 1e2494e92..8f07ce275 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp @@ -31,6 +31,7 @@ const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x100; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH = NOT_AN_ERROR | MATCH_WITH_WRONG_CASE | MATCH_WITH_MISSING_ACCENT | MATCH_WITH_DIGRAPH; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_A_PERFECT_MATCH = NOT_AN_ERROR; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION = diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.h b/native/jni/src/suggest/core/dictionary/error_type_utils.h index fd1d5fcff..e92c509fa 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.h +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.h @@ -52,6 +52,10 @@ class ErrorTypeUtils { return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0; } + static bool isPerfectMatch(const ErrorType containedErrorTypes) { + return (containedErrorTypes & ~ERRORS_TREATED_AS_A_PERFECT_MATCH) == 0; + } + static bool isExactMatchWithIntentionalOmission(const ErrorType containedErrorTypes) { return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION) == 0; @@ -73,6 +77,7 @@ class ErrorTypeUtils { DISALLOW_IMPLICIT_CONSTRUCTORS(ErrorTypeUtils); static const ErrorType ERRORS_TREATED_AS_AN_EXACT_MATCH; + static const ErrorType ERRORS_TREATED_AS_A_PERFECT_MATCH; static const ErrorType ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION; }; } // namespace latinime diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index ce3684a1c..b9dda83ad 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -30,7 +30,7 @@ class Scoring { public: virtual int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, - const bool boostExactMatches) const = 0; + const bool boostExactMatches, const bool hasProbabilityZero) const = 0; virtual void getMostProbableString(const DicTraverseSession *const traverseSession, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const = 0; diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index 23103b9f7..74db95953 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -161,7 +161,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), (forceCommitMultiWords && terminalDicNode->hasMultipleWords()), - boostExactMatches); + boostExactMatches, wordAttributes.getProbability() == 0); // Don't output invalid or blocked offensive words. However, we still need to submit their // shortcuts if any. diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index a6f9a8b23..856808a74 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -24,6 +24,7 @@ const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120; const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f; const float ScoringParams::EXACT_MATCH_PROMOTION = 1.1f; +const float ScoringParams::PERFECT_MATCH_PROMOTION = 1.1f; const float ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH = 0.01f; const float ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH = 0.02f; const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index b8f889559..6f327a370 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -34,6 +34,7 @@ class ScoringParams { static const int THRESHOLD_SHORT_WORD_LENGTH; static const float EXACT_MATCH_PROMOTION; + static const float PERFECT_MATCH_PROMOTION; static const float CASE_ERROR_PENALTY_FOR_EXACT_MATCH; static const float ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; static const float DIGRAPH_PENALTY_FOR_EXACT_MATCH; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 0240bcf54..6acd767ea 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -44,23 +44,50 @@ class TypingScoring : public Scoring { AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, - const bool boostExactMatches) const { + const bool boostExactMatches, const bool hasProbabilityZero) const { const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + static_cast(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance; if (forceCommit) { score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD; } - if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { - score += ScoringParams::EXACT_MATCH_PROMOTION; - if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) { - score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; + if (hasProbabilityZero) { + // Previously, when both legitimate 0-frequency words (such as distracters) and + // offensive words were encoded in the same way, distracters would never show up + // when the user blocked offensive words (the default setting, as well as the + // setting for regression tests). + // + // When b/11031090 was fixed and a separate encoding was used for offensive words, + // 0-frequency words would no longer be blocked when they were an "exact match" + // (where case mismatches and accent mismatches would be considered an "exact + // match"). The exact match boosting functionality meant that, for example, when + // the user typed "mt" they would be suggested the word "Mt", although they most + // probably meant to type "my". + // + // For this reason, we introduced this change, which does the following: + // * Defines the "perfect match" as a really exact match, with no room for case or + // accent mismatches + // * When the target word has probability zero (as "Mt" does, because it is a + // distracter), ONLY boost its score if it is a perfect match. + // + // By doing this, when the user types "mt", the word "Mt" will NOT be boosted, and + // they will get "my". However, if the user makes an explicit effort to type "Mt", + // we do boost the word "Mt" so that the user's input is not autocorrected to "My". + if (boostExactMatches && ErrorTypeUtils::isPerfectMatch(containedErrorTypes)) { + score += ScoringParams::PERFECT_MATCH_PROMOTION; } - if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) { - score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; - } - if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { - score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH; + } else { + if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { + score += ScoringParams::EXACT_MATCH_PROMOTION; + if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) { + score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; + } + if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) { + score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; + } + if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { + score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH; + } } } return static_cast(score * SUGGEST_INTERFACE_OUTPUT_SCALE);