diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 974bb483b..34a646f80 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -381,6 +381,7 @@ typedef enum { CT_TRANSPOSITION, CT_COMPLETION, CT_TERMINAL, + CT_TERMINAL_INSERTION, // Create new word with space omission CT_NEW_WORD_SPACE_OMITTION, // Create new word with space substitution diff --git a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h index 90f75d0c6..1f4d2570e 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h @@ -31,6 +31,7 @@ #define PROF_TRANSPOSITION(profiler) profiler.profTransposition() #define PROF_NEARESTKEY(profiler) profiler.profNearestKey() #define PROF_TERMINAL(profiler) profiler.profTerminal() +#define PROF_TERMINAL_INSERTION(profiler) profiler.profTerminalInsertion() #define PROF_NEW_WORD(profiler) profiler.profNewWord() #define PROF_NEW_WORD_BIGRAM(profiler) profiler.profNewWordBigram() #define PROF_NODE_RESET(profiler) profiler.reset() @@ -47,6 +48,7 @@ #define PROF_TRANSPOSITION(profiler) #define PROF_NEARESTKEY(profiler) #define PROF_TERMINAL(profiler) +#define PROF_TERMINAL_INSERTION(profiler) #define PROF_NEW_WORD(profiler) #define PROF_NEW_WORD_BIGRAM(profiler) #define PROF_NODE_RESET(profiler) @@ -62,7 +64,7 @@ class DicNodeProfiler { : mProfOmission(0), mProfInsertion(0), mProfTransposition(0), mProfAdditionalProximity(0), mProfSubstitution(0), mProfSpaceSubstitution(0), mProfSpaceOmission(0), - mProfMatch(0), mProfCompletion(0), mProfTerminal(0), + mProfMatch(0), mProfCompletion(0), mProfTerminal(0), mProfTerminalInsertion(0), mProfNearestKey(0), mProfNewWord(0), mProfNewWordBigram(0) {} int mProfOmission; @@ -75,6 +77,7 @@ class DicNodeProfiler { int mProfMatch; int mProfCompletion; int mProfTerminal; + int mProfTerminalInsertion; int mProfNearestKey; int mProfNewWord; int mProfNewWordBigram; @@ -123,6 +126,10 @@ class DicNodeProfiler { ++mProfTerminal; } + void profTerminalInsertion() { + ++mProfTerminalInsertion; + } + void profNewWord() { ++mProfNewWord; } diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index 117f48f29..58729229f 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -50,6 +50,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_TERMINAL: PROF_TERMINAL(node->mProfiler); return; + case CT_TERMINAL_INSERTION: + PROF_TERMINAL_INSERTION(node->mProfiler); + return; case CT_NEW_WORD_SPACE_SUBSTITUTION: PROF_SPACE_SUBSTITUTION(node->mProfiler); return; @@ -113,6 +116,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return weighting->getCompletionCost(traverseSession, dicNode); case CT_TERMINAL: return weighting->getTerminalSpatialCost(traverseSession, dicNode); + case CT_TERMINAL_INSERTION: + return weighting->getTerminalInsertionCost(traverseSession, dicNode); case CT_NEW_WORD_SPACE_SUBSTITUTION: return weighting->getSpaceSubstitutionCost(traverseSession, dicNode); case CT_INSERTION: @@ -146,6 +151,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n traverseSession->getBinaryDictionaryInfo(), dicNode, multiBigramMap); return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); } + case CT_TERMINAL_INSERTION: + return 0.0f; case CT_NEW_WORD_SPACE_SUBSTITUTION: return weighting->getNewWordBigramLanguageCost( traverseSession, parentDicNode, multiBigramMap); @@ -163,9 +170,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_OMISSION: return 0; case CT_ADDITIONAL_PROXIMITY: - return 0; + return 0; /* 0 because CT_MATCH will be called */ case CT_SUBSTITUTION: - return 0; + return 0; /* 0 because CT_MATCH will be called */ case CT_NEW_WORD_SPACE_OMITTION: return 0; case CT_MATCH: @@ -174,12 +181,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 1; case CT_TERMINAL: return 0; + case CT_TERMINAL_INSERTION: + return 1; case CT_NEW_WORD_SPACE_SUBSTITUTION: return 1; case CT_INSERTION: - return 2; + return 2; /* look ahead + skip the current char */ case CT_TRANSPOSITION: - return 2; + return 2; /* look ahead + skip the current char */ default: return 0; } diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index 781a7adbc..2d49e98a6 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -67,6 +67,10 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; + virtual float getTerminalInsertionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual float getTerminalLanguageCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, float dicNodeLanguageImprobability) const = 0; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index d6383b958..73e9714bd 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -365,17 +365,17 @@ void Suggest::processTerminalDicNode( if (!dicNode->isTerminalWordNode()) { return; } - if (TRAVERSAL->needsToTraverseAllUserInput() - && dicNode->getInputIndex(0) < traverseSession->getInputSize()) { - return; - } - if (dicNode->shouldBeFilterdBySafetyNetForBigram()) { return; } // Create a non-cached node here. DicNode terminalDicNode; DicNodeUtils::initByCopy(dicNode, &terminalDicNode); + if (TRAVERSAL->needsToTraverseAllUserInput() + && dicNode->getInputIndex(0) < traverseSession->getInputSize()) { + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0, + &terminalDicNode, traverseSession->getMultiBigramMap()); + } Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0, &terminalDicNode, traverseSession->getMultiBigramMap()); traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode); diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index a8f797c5c..4157f411e 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -34,6 +34,7 @@ const float ScoringParams::OMISSION_COST = 0.458f; const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f; const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f; const float ScoringParams::INSERTION_COST = 0.730f; +const float ScoringParams::TERMINAL_INSERTION_COST = 0.93f; const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f; const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.70f; const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index 4ebcc7dc3..a743b4d81 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -42,6 +42,7 @@ class ScoringParams { static const float OMISSION_COST_SAME_CHAR; static const float OMISSION_COST_FIRST_CHAR; static const float INSERTION_COST; + static const float TERMINAL_INSERTION_COST; static const float INSERTION_COST_SAME_CHAR; static const float INSERTION_COST_PROXIMITY_CHAR; static const float INSERTION_COST_FIRST_CHAR; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index e4c69d1f6..408b12ae9 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -44,6 +44,7 @@ ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType, break; case CT_SUBSTITUTION: case CT_INSERTION: + case CT_TERMINAL_INSERTION: case CT_TRANSPOSITION: return ET_EDIT_CORRECTION; case CT_NEW_WORD_SPACE_OMITTION: diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 1bb160738..7cddb0882 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -175,6 +175,15 @@ class TypingWeighting : public Weighting { return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; } + float getTerminalInsertionCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { + const int inputIndex = dicNode->getInputIndex(0); + const int inputSize = traverseSession->getInputSize(); + ASSERT(inputIndex < inputSize); + // TODO: Implement more efficient logic + return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex); + } + AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { return false; }