diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h index 1c4061fd8..2d2e19501 100644 --- a/native/jni/src/binary_format.h +++ b/native/jni/src/binary_format.h @@ -92,6 +92,7 @@ class BinaryFormat { const int unigramProbability, const int bigramProbability); static int getProbability(const int position, const std::map *bigramMap, const uint8_t *bigramFilter, const int unigramProbability); + static float getMultiWordCostMultiplier(const uint8_t *const dict); // Flags for special processing // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or @@ -241,6 +242,17 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t * return ((msb & 0x7F) << 8) | dict[(*pos)++]; } +inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) { + const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE"); + if (headerValue == S_INT_MIN) { + return 1.0f; + } + if (headerValue <= 0) { + return static_cast(MAX_VALUE_FOR_WEIGHTING); + } + return 100.0f / static_cast(headerValue); +} + inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) { return dict[(*pos)++]; } diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index a7b023a75..6ef9f414b 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -424,10 +424,9 @@ typedef enum { CT_OMISSION, CT_INSERTION, CT_TRANSPOSITION, - CT_SPACE_SUBSTITUTION, - CT_SPACE_OMISSION, CT_COMPLETION, CT_TERMINAL, - CT_NEW_WORD, + CT_NEW_WORD_SPACE_OMITTION, + CT_NEW_WORD_SPACE_SUBSTITUTION, } CorrectionType; #endif // LATINIME_DEFINES_H diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index e62b70423..b9c0b8129 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -38,7 +38,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: PROF_SUBSTITUTION(node->mProfiler); return; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: PROF_NEW_WORD(node->mProfiler); return; case CT_MATCH: @@ -50,7 +50,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_TERMINAL: PROF_TERMINAL(node->mProfiler); return; - case CT_SPACE_SUBSTITUTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: PROF_SPACE_SUBSTITUTION(node->mProfiler); return; case CT_INSERTION: @@ -107,16 +107,16 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: // only used for typing return weighting->getSubstitutionCost(); - case CT_NEW_WORD: - return weighting->getNewWordCost(dicNode); + case CT_NEW_WORD_SPACE_OMITTION: + return weighting->getNewWordCost(traverseSession, dicNode); case CT_MATCH: return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_COMPLETION: return weighting->getCompletionCost(traverseSession, dicNode); case CT_TERMINAL: return weighting->getTerminalSpatialCost(traverseSession, dicNode); - case CT_SPACE_SUBSTITUTION: - return weighting->getSpaceSubstitutionCost(); + case CT_NEW_WORD_SPACE_SUBSTITUTION: + return weighting->getSpaceSubstitutionCost(traverseSession, dicNode); case CT_INSERTION: return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode); case CT_TRANSPOSITION: @@ -135,7 +135,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 0.0f; case CT_SUBSTITUTION: return 0.0f; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); case CT_MATCH: return 0.0f; @@ -147,8 +147,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n traverseSession->getOffsetDict(), dicNode, bigramCacheMap); return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); } - case CT_SPACE_SUBSTITUTION: - return 0.0f; + case CT_NEW_WORD_SPACE_SUBSTITUTION: + return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); case CT_INSERTION: return 0.0f; case CT_TRANSPOSITION: @@ -168,7 +168,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: // Should return true? return false; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: return false; case CT_MATCH: return false; @@ -176,7 +176,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return false; case CT_TERMINAL: return false; - case CT_SPACE_SUBSTITUTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: return false; case CT_INSERTION: return true; @@ -197,7 +197,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return false; case CT_SUBSTITUTION: return false; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: return false; case CT_MATCH: return weighting->isProximityDicNode(traverseSession, dicNode); @@ -205,7 +205,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return false; case CT_TERMINAL: return false; - case CT_SPACE_SUBSTITUTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: return false; case CT_INSERTION: return false; @@ -224,7 +224,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 0; case CT_SUBSTITUTION: return 0; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: return 0; case CT_MATCH: return 1; @@ -232,7 +232,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 0; case CT_TERMINAL: return 0; - case CT_SPACE_SUBSTITUTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: return 1; case CT_INSERTION: return 2; diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index b92dbe278..bce479c51 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -56,7 +56,8 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual float getNewWordCost(const DicNode *const dicNode) const = 0; + virtual float getNewWordCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; virtual float getNewWordBigramCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, @@ -76,7 +77,8 @@ class Weighting { virtual float getSubstitutionCost() const = 0; - virtual float getSpaceSubstitutionCost() const = 0; + virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; Weighting() {} virtual ~Weighting() {} diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 5b783a2ba..3c44db21c 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -16,6 +16,7 @@ #include "suggest/core/session/dic_traverse_session.h" +#include "binary_format.h" #include "defines.h" #include "dictionary.h" #include "dic_traverse_wrapper.h" @@ -63,6 +64,7 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer; void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, int prevWordLength) { mDictionary = dictionary; + mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict()); if (!prevWord) { mPrevWordPos = NOT_VALID_WORD; return; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index fe0527639..d9c2a51d0 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -36,7 +36,8 @@ class DicTraverseSession { AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), mDictionary(0), mDicNodesCache(), mBigramCacheMap(), - mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1) { + mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), + mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. } @@ -52,6 +53,7 @@ class DicTraverseSession { const int maxPointerCount); void resetCache(const int nextActiveCacheSize, const int maxWords); + // TODO: Remove const uint8_t *getOffsetDict() const; int getDictFlags() const; @@ -150,6 +152,10 @@ class DicTraverseSession { return mProximityInfoStates[0].touchPositionCorrectionEnabled(); } + float getMultiWordCostMultiplier() const { + return mMultiWordCostMultiplier; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); // threshold to start caching @@ -170,6 +176,11 @@ class DicTraverseSession { int mInputSize; bool mPartiallyCommited; int mMaxPointerCount; + + ///////////////////////////////// + // Configuration per dictionary + float mMultiWordCostMultiplier; + }; } // namespace latinime #endif // LATINIME_DIC_TRAVERSE_SESSION_H diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 67d351fa1..9de2cd2e2 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -33,16 +33,9 @@ namespace latinime { // Initialization of class constants. -const int Suggest::LOOKAHEAD_DIC_NODES_CACHE_SIZE = 25; const int Suggest::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f; -const float Suggest::AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD = 0.6f; - -const bool Suggest::CORRECT_SPACE_OMISSION = true; -const bool Suggest::CORRECT_TRANSPOSITION = true; -const bool Suggest::CORRECT_INSERTION = true; -const bool Suggest::CORRECT_OMISSION_G = true; /** * Returns a set of suggestions for the given input touch points. The commitPoint argument indicates @@ -270,12 +263,8 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { // latest touch point yet. These are needed to apply look-ahead correction operations // that require special handling of the latest touch point. For example, with insertions // (e.g., "thiis" -> "this") the latest touch point should not be consumed at all. - if (CORRECT_TRANSPOSITION) { - processDicNodeAsTransposition(traverseSession, &dicNode); - } - if (CORRECT_INSERTION) { - processDicNodeAsInsertion(traverseSession, &dicNode); - } + processDicNodeAsTransposition(traverseSession, &dicNode); + processDicNodeAsInsertion(traverseSession, &dicNode); } else { // !isLookAheadCorrection // Only consider typing error corrections if the normalized compound distance is // below a spatial distance threshold. @@ -531,13 +520,10 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode DicNode newDicNode; DicNodeUtils::initAsRootWithPreviousWord(traverseSession->getDicRootPos(), traverseSession->getOffsetDict(), dicNode, &newDicNode); - Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_NEW_WORD, traverseSession, dicNode, + const CorrectionType correctionType = spaceSubstitution ? + CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION; + Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode, &newDicNode, traverseSession->getBigramCacheMap()); - if (spaceSubstitution) { - // Merge this with CT_NEW_WORD - Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SPACE_SUBSTITUTION, - traverseSession, 0, &newDicNode, 0 /* bigramCacheMap */); - } traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); } } // namespace latinime diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index becd6c1de..875cbe4e0 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -76,31 +76,16 @@ class Suggest : public SuggestInterface { void processDicNodeAsMatch(DicTraverseSession *traverseSession, DicNode *childDicNode) const; - // Dic nodes cache size for lookahead (autocompletion) - static const int LOOKAHEAD_DIC_NODES_CACHE_SIZE; - // Max characters to lookahead - static const int MAX_LOOKAHEAD; // Inputs longer than this will autocorrect if the suggestion is multi-word static const int MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT; static const int MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE; - // Base value for converting costs into scores (low so will not autocorrect without classifier) - static const float BASE_OUTPUT_SCORE; // Threshold for autocorrection classifier static const float AUTOCORRECT_CLASSIFICATION_THRESHOLD; - // Threshold for computing the language model feature for autocorrect classification - static const float AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD; - - // Typing error correction settings - static const bool CORRECT_SPACE_OMISSION; - static const bool CORRECT_TRANSPOSITION; - static const bool CORRECT_INSERTION; const Traversal *const TRAVERSAL; const Scoring *const SCORING; const Weighting *const WEIGHTING; - - static const bool CORRECT_OMISSION_G; }; } // namespace latinime #endif // LATINIME_SUGGEST_IMPL_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp b/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp index 66f8ba9fa..e7e40e34d 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp @@ -18,7 +18,7 @@ namespace latinime { const bool TypingTraversal::CORRECT_OMISSION = true; -const bool TypingTraversal::CORRECT_SPACE_SUBSTITUTION = true; -const bool TypingTraversal::CORRECT_SPACE_OMISSION = true; +const bool TypingTraversal::CORRECT_NEW_WORD_SPACE_SUBSTITUTION = true; +const bool TypingTraversal::CORRECT_NEW_WORD_SPACE_OMISSION = true; const TypingTraversal TypingTraversal::sInstance; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index f22029a2c..9f8347452 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -66,7 +66,7 @@ class TypingTraversal : public Traversal { AK_FORCE_INLINE bool isSpaceSubstitutionTerminal( const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { - if (!CORRECT_SPACE_SUBSTITUTION) { + if (!CORRECT_NEW_WORD_SPACE_SUBSTITUTION) { return false; } if (!canDoLookAheadCorrection(traverseSession, dicNode)) { @@ -80,7 +80,7 @@ class TypingTraversal : public Traversal { AK_FORCE_INLINE bool isSpaceOmissionTerminal( const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { - if (!CORRECT_SPACE_OMISSION) { + if (!CORRECT_NEW_WORD_SPACE_OMISSION) { return false; } const int inputSize = traverseSession->getInputSize(); @@ -173,8 +173,8 @@ class TypingTraversal : public Traversal { private: DISALLOW_COPY_AND_ASSIGN(TypingTraversal); static const bool CORRECT_OMISSION; - static const bool CORRECT_SPACE_SUBSTITUTION; - static const bool CORRECT_SPACE_OMISSION; + static const bool CORRECT_NEW_WORD_SPACE_SUBSTITUTION; + static const bool CORRECT_NEW_WORD_SPACE_OMISSION; static const TypingTraversal sInstance; TypingTraversal() {} diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 2dcee343f..74e4e34e4 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -128,10 +128,12 @@ class TypingWeighting : public Weighting { return cost + weightedDistance; } - float getNewWordCost(const DicNode *const dicNode) const { + float getNewWordCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { const bool isCapitalized = dicNode->isCapitalized(); - return isCapitalized ? + const float cost = isCapitalized ? ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD; + return cost * traverseSession->getMultiWordCostMultiplier(); } float getNewWordBigramCost( @@ -183,8 +185,13 @@ class TypingWeighting : public Weighting { return ScoringParams::SUBSTITUTION_COST; } - AK_FORCE_INLINE float getSpaceSubstitutionCost() const { - return ScoringParams::SPACE_SUBSTITUTION_COST; + AK_FORCE_INLINE float getSpaceSubstitutionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { + const bool isCapitalized = dicNode->isCapitalized(); + const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + (isCapitalized ? + ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD); + return cost * traverseSession->getMultiWordCostMultiplier(); } private: