diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 3f64d07b2..25299948d 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -128,7 +128,7 @@ class DicNode { void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos, const int childrenCount) { mIsUsed = true; - mIsCachedForNextSuggestion = false; + mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; mDicNodeProperties.init( pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); // TODO: Move to dicNodeState? @@ -479,6 +479,11 @@ class DicNode { return mDicNodeProperties.getDepth(); } + // "Length" includes spaces. + inline uint16_t getTotalLength() const { + return getDepth() + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); + } + AK_FORCE_INLINE void dump(const char *tag) const { #if DEBUG_DICT DUMP_WORD_AND_SCORE(tag); diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index 0c57ca001..117f48f29 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -106,7 +106,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n // only used for typing return weighting->getSubstitutionCost(); case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordCost(traverseSession, dicNode); + return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG); case CT_MATCH: return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_COMPLETION: @@ -134,7 +134,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: return 0.0f; case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); + return weighting->getNewWordBigramLanguageCost( + traverseSession, parentDicNode, multiBigramMap); case CT_MATCH: return 0.0f; case CT_COMPLETION: @@ -146,7 +147,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); } case CT_NEW_WORD_SPACE_SUBSTITUTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); + return weighting->getNewWordBigramLanguageCost( + traverseSession, parentDicNode, multiBigramMap); case CT_INSERTION: return 0.0f; case CT_TRANSPOSITION: diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index 0d2745b40..781a7adbc 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -56,10 +56,10 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual float getNewWordCost(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode) const = 0; + virtual float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const = 0; - virtual float getNewWordBigramCost( + virtual float getNewWordBigramLanguageCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const = 0; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 17fa11082..a1c99182a 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -138,12 +138,12 @@ class TypingWeighting : public Weighting { return cost + weightedDistance; } - float getNewWordCost(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode) const { + float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); } - float getNewWordBigramCost(const DicTraverseSession *const traverseSession, + float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const { return DicNodeUtils::getBigramNodeImprobability(traverseSession->getBinaryDictionaryInfo(),