diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 9f03e30d1..19f92cc0b 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -18,7 +18,6 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" -#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" namespace latinime { @@ -73,25 +72,12 @@ namespace latinime { if (dicNode->hasMultipleWords() && !dicNode->isValidMultipleWordSuggestion()) { return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); } - const int probability = getBigramNodeProbability(dictionaryStructurePolicy, dicNode, - multiBigramMap); + const int probability = dictionaryStructurePolicy->getProbabilityOfWordInContext( + dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap); // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. 
const float cost = static_cast<float>(MAX_PROBABILITY - probability) / static_cast<float>(MAX_PROBABILITY); return cost; } -/* static */ int DicNodeUtils::getBigramNodeProbability( - const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) { - const int unigramProbability = dicNode->getUnigramProbability(); - if (multiBigramMap) { - const int *const prevWordIds = dicNode->getPrevWordIds(); - return multiBigramMap->getBigramProbability(dictionaryStructurePolicy, - prevWordIds, dicNode->getWordId(), unigramProbability); - } - return dictionaryStructurePolicy->getProbability(unigramProbability, - NOT_A_PROBABILITY); -} - } // namespace latinime diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h index 56ff6e3d0..961a1c29d 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -46,10 +46,6 @@ class DicNodeUtils { DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils); // Max number of bigrams to look up static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500; - - static int getBigramNodeProbability( - const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const DicNode *const dicNode, MultiBigramMap *const multiBigramMap); }; } // namespace latinime #endif // LATINIME_DIC_NODE_UTILS_H diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index aeeb66f93..4e55418ae 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -29,6 +29,7 @@ namespace latinime { class DicNode; class DicNodeVector; class DictionaryHeaderStructurePolicy; +class MultiBigramMap; class NgramListener; class PrevWordsInfo; class 
UnigramProperty; @@ -56,6 +57,10 @@ class DictionaryStructureWithBufferPolicy { virtual int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const = 0; + virtual int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const = 0; + + // TODO: Remove virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0; virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 6480374df..88982e540 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -28,6 +28,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" @@ -117,6 +118,26 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, return getWordIdFromTerminalPtNodePos(ptNodePos); } +int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const { + if (wordId == NOT_A_WORD_ID) { + return NOT_A_PROBABILITY; + } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); + const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); + if (multiBigramMap) { + return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds, + wordId, 
ptNodeParams.getProbability()); + } + if (prevWordIds) { + const int probability = getProbabilityOfWord(prevWordIds, wordId); + if (probability != NOT_A_PROBABILITY) { + return probability; + } + } + return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); +} + int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, const int bigramProbability) const { if (mHeaderPolicy->isDecayingDict()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 562c219f4..06d704174 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -91,6 +91,9 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; + int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const; + int getProbability(const int unigramProbability, const int bigramProbability) const; int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index e0406ab07..80bbf47c0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -21,6 +21,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/multi_bigram_map.h" 
#include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/session/prev_words_info.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" @@ -281,6 +282,27 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, return getWordIdFromTerminalPtNodePos(ptNodePos); } +int PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const { + if (wordId == NOT_A_WORD_ID) { + return NOT_A_PROBABILITY; + } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); + const PtNodeParams ptNodeParams = + mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + if (multiBigramMap) { + return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds, + wordId, ptNodeParams.getProbability()); + } + if (prevWordIds) { + const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId); + if (bigramProbability != NOT_A_PROBABILITY) { + return bigramProbability; + } + } + return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); +} + int PatriciaTriePolicy::getProbability(const int unigramProbability, const int bigramProbability) const { // Due to space constraints, the probability for bigrams is approximate - the lower the unigram diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 66df52779..a2d6b6fa6 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -66,6 +66,9 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; + int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId, + 
MultiBigramMap *const multiBigramMap) const; + int getProbability(const int unigramProbability, const int bigramProbability) const; int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 466c49952..6de3e5a81 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -20,6 +20,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" @@ -112,6 +113,28 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, return ptNodeParams.getTerminalId(); } +int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const { + // TODO: Quit using MultiBigramMap. 
+ if (wordId == NOT_A_WORD_ID) { + return NOT_A_PROBABILITY; + } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); + const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); + if (multiBigramMap) { + return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds, + wordId, ptNodeParams.getProbability()); + } + if (prevWordIds) { + const int probability = getProbabilityOfWord(prevWordIds, wordId); + if (probability != NOT_A_PROBABILITY) { + return probability; + } + } + return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); +} + int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, const int bigramProbability) const { if (mHeaderPolicy->isDecayingDict()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index 0b8eec40b..c9df9df4b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -68,6 +68,9 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; + int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const; + int getProbability(const int unigramProbability, const int bigramProbability) const; int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;