diff --git a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp index b8106377c..e37811b88 100644 --- a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp @@ -78,7 +78,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; outputAutoCommitFirstWordConfidence[0] = computeFirstWordConfidence(&terminals[0]); } - + const bool boostExactMatches = traverseSession->getDictionaryStructurePolicy()-> + getHeaderStructurePolicy()->shouldBoostExactMatches(); // Output suggestion results here for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; ++terminalIndex) { @@ -102,7 +103,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; && !(isPossiblyOffensiveWord && isFirstCharUppercase); const int outputTypeFlags = (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) - | (isSafeExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); + | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); // Entries that are blacklisted or do not represent a word should not be output. const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord(); @@ -113,7 +114,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) - || (isValidWord && scoringPolicy->doesAutoCorrectValidWord())); + || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()), + boostExactMatches); if (maxScore < finalScore && isValidWord) { maxScore = finalScore; } @@ -147,7 +149,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; scoringPolicy->calculateFinalScore(compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), - true /* forceCommit */) : finalScore; + true /* forceCommit */, boostExactMatches) : finalScore; const int updatedOutputWordIndex = outputShortcuts(&shortcutIt, outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h index b76b13971..417620e00 100644 --- a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h @@ -40,6 +40,8 @@ class DictionaryHeaderStructurePolicy { virtual void readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const = 0; + virtual bool shouldBoostExactMatches() const = 0; + protected: DictionaryHeaderStructurePolicy() {} diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index 783383450..e581a97c3 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -28,7 +28,8 @@ class DicTraverseSession; class Scoring { public: virtual int calculateFinalScore(const float compoundDistance, const int inputSize, - const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit) const = 0; + const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, + const bool boostExactMatches) const = 0; virtual bool getMostProbableString(const DicTraverseSession *const traverseSession, const int terminalSize, const float languageWeight, int *const outputCodePoints, int *const type, int *const freq) const = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index a44f9f0fc..1320c6560 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -146,6 +146,11 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mHasHistoricalInfoOfWords; } + AK_FORCE_INLINE bool shouldBoostExactMatches() const { + // TODO: Investigate better ways to handle exact matches for personalized dictionaries. + return !isDecayingDict(); + } + void readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index c777e7238..8b405e8de 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -50,14 +50,14 @@ class TypingScoring : public Scoring { AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, - const bool forceCommit) const { + const bool forceCommit, const bool boostExactMatches) const { const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + static_cast(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance; if (forceCommit) { score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD; } - if (ErrorTypeUtils::isExactMatch(containedErrorTypes)) { + if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { score += ScoringParams::EXACT_MATCH_PROMOTION; if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) { score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;