diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp index 94d7c886f..f71d4c5f0 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp @@ -54,15 +54,18 @@ namespace latinime { current.swap(next); } - int maxUnigramProbability = NOT_A_PROBABILITY; + int maxProbability = NOT_A_PROBABILITY; for (const DicNode &dicNode : current) { if (!dicNode.isTerminalDicNode()) { continue; } + const WordAttributes wordAttributes = + dictionaryStructurePolicy->getWordAttributesInContext(dicNode.getPrevWordIds(), + dicNode.getWordId(), nullptr /* multiBigramMap */); // dicNode can contain case errors, accent errors, intentional omissions or digraphs. - maxUnigramProbability = std::max(maxUnigramProbability, dicNode.getUnigramProbability()); + maxProbability = std::max(maxProbability, wordAttributes.getProbability()); } - return maxUnigramProbability; + return maxProbability; } /* static */ void DictionaryUtils::processChildDicNodes( diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h index 8ddaa0514..6dfa7e314 100644 --- a/native/jni/src/suggest/core/policy/traversal.h +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -48,7 +48,8 @@ class Traversal { virtual int getTerminalCacheSize() const = 0; virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0; + virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode, + const int probability) const = 0; protected: Traversal() {} diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 66c87f04c..947d41f4b 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -21,6 +21,7 @@ #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/dictionary/digraph_utils.h" +#include "suggest/core/dictionary/word_attributes.h" #include "suggest/core/layout/proximity_info.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/core/policy/traversal.h" @@ -412,7 +413,11 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN */ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, const bool spaceSubstitution) const { - if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode)) { + const WordAttributes wordAttributes = + traverseSession->getDictionaryStructurePolicy()->getWordAttributesInContext( + dicNode->getPrevWordIds(), dicNode->getWordId(), + traverseSession->getMultiBigramMap()); + if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode, wordAttributes.getProbability())) { return; } diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index ed9df8eb3..b64ee8be4 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -161,9 +161,8 @@ class TypingTraversal : public Traversal { return true; } - AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode) const { - // TODO: Quit using unigram probability and use probability in the context. - const int probability = dicNode->getUnigramProbability(); + AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode, + const int probability) const { if (probability < ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY) { return false; }