From 9559dd2e30de288a9ff7069bfc59f8500b949a88 Mon Sep 17 00:00:00 2001 From: Tom Ouyang Date: Tue, 16 Apr 2013 16:34:49 -0700 Subject: [PATCH] Improve bigram frequency lookup Bug: 8592527 Change-Id: I1908bcb552279b9acb140fe4d8d26b10ed9eda72 --- native/jni/src/binary_format.h | 72 +++++++++++++++ native/jni/src/defines.h | 9 ++ native/jni/src/multi_bigram_map.h | 89 +++++++++++++++++++ .../suggest/core/dicnode/dic_node_utils.cpp | 83 +++-------------- .../src/suggest/core/dicnode/dic_node_utils.h | 12 +-- .../jni/src/suggest/core/policy/weighting.cpp | 15 ++-- .../jni/src/suggest/core/policy/weighting.h | 8 +- .../core/session/dic_traverse_session.cpp | 2 +- .../core/session/dic_traverse_session.h | 8 +- native/jni/src/suggest/core/suggest.cpp | 20 +++-- .../policyimpl/typing/typing_weighting.h | 5 +- 11 files changed, 217 insertions(+), 106 deletions(-) create mode 100644 native/jni/src/multi_bigram_map.h diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h index 432a56b7f..06f50dc7f 100644 --- a/native/jni/src/binary_format.h +++ b/native/jni/src/binary_format.h @@ -23,6 +23,7 @@ #include "bloom_filter.h" #include "char_utils.h" +#include "hash_map_compat.h" namespace latinime { @@ -93,7 +94,13 @@ class BinaryFormat { const int unigramProbability, const int bigramProbability); static int getProbability(const int position, const std::map *bigramMap, const uint8_t *bigramFilter, const int unigramProbability); + static int getBigramProbabilityFromHashMap(const int position, + const hash_map_compat *bigramMap, const int unigramProbability); static float getMultiWordCostMultiplier(const uint8_t *const dict); + static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, + hash_map_compat *bigramMap); + static int getBigramProbability(const uint8_t *const root, int position, + const int nextPosition, const int unigramProbability); // Flags for special processing // Those *must* match the flags in makedict 
(BinaryDictInputOutput#*_PROCESSING_FLAG) or @@ -105,6 +112,8 @@ class BinaryFormat { private: DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat); + static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); + static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80; @@ -687,5 +696,68 @@ inline int BinaryFormat::getProbability(const int position, const std::map *bigramMap, const int unigramProbability) { + if (!bigramMap) return backoff(unigramProbability); + const hash_map_compat::const_iterator bigramProbabilityIt = bigramMap->find(position); + if (bigramProbabilityIt != bigramMap->end()) { + const int bigramProbability = bigramProbabilityIt->second; + return computeProbabilityForBigram(unigramProbability, bigramProbability); + } + return backoff(unigramProbability); +} + +AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap( + const uint8_t *const root, int position, hash_map_compat *bigramMap) { + position = getBigramListPositionForWordPosition(root, position); + if (0 == position) return; + + uint8_t bigramFlags; + do { + bigramFlags = getFlagsAndForwardPointer(root, &position); + const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; + const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags, + &position); + (*bigramMap)[bigramPos] = probability; + } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); +} + +AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position, + const int nextPosition, const int unigramProbability) { + position = getBigramListPositionForWordPosition(root, position); + if (0 == position) return backoff(unigramProbability); + + uint8_t bigramFlags; + do { + bigramFlags = getFlagsAndForwardPointer(root, &position); + const int bigramPos = getAttributeAddressAndForwardPointer( + root, bigramFlags, &position); + if (bigramPos == 
nextPosition) { + const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; + return computeProbabilityForBigram(unigramProbability, bigramProbability); + } + } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + return backoff(unigramProbability); +} + +// Returns the position of the word's bigram list, or 0 if the word is invalid or has none. +AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition( + const uint8_t *const root, int position) { + if (NOT_VALID_WORD == position) return 0; + const uint8_t flags = getFlagsAndForwardPointer(root, &position); + if (!(flags & FLAG_HAS_BIGRAMS)) return 0; + if (flags & FLAG_HAS_MULTIPLE_CHARS) { + position = skipOtherCharacters(root, position); + } else { + getCodePointAndForwardPointer(root, &position); + } + position = skipProbability(flags, position); + position = skipChildrenPosition(flags, position); + position = skipShortcuts(root, flags, position); + return position; +} + } // namespace latinime #endif // LATINIME_BINARY_FORMAT_H diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index d3b351f81..eb59744f6 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -379,6 +379,15 @@ static inline void prof_out(void) { #error "BIGRAM_FILTER_MODULO is larger than BIGRAM_FILTER_BYTE_SIZE" #endif +// Max number of bigram maps (previous word contexts) to be cached. Increasing this number could +// improve bigram lookup speed for multi-word suggestions, but at the cost of more memory usage. +// Also, there are diminishing returns since the most frequently used bigrams are typically near +// the beginning of the input and are thus the first ones to be cached. Note that these bigrams +// are reset for each new composing word. +#define MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP 25 +// Most common previous word contexts currently have 100 bigrams +#define DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP 100 + template AK_FORCE_INLINE const T &min(const T &a, const T &b) { return a < b ?
a : b; } template AK_FORCE_INLINE const T &max(const T &a, const T &b) { return a > b ? a : b; } diff --git a/native/jni/src/multi_bigram_map.h b/native/jni/src/multi_bigram_map.h new file mode 100644 index 000000000..7e1b6301f --- /dev/null +++ b/native/jni/src/multi_bigram_map.h @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_MULTI_BIGRAM_MAP_H +#define LATINIME_MULTI_BIGRAM_MAP_H + +#include +#include + +#include "defines.h" +#include "binary_format.h" +#include "hash_map_compat.h" + +namespace latinime { + +// Class for caching bigram maps for multiple previous word contexts. This is useful since the +// algorithm needs to look up the set of bigrams for every word pair that occurs in every +// multi-word suggestion. +class MultiBigramMap { + public: + MultiBigramMap() : mBigramMaps() {} + ~MultiBigramMap() {} + + // Look up the bigram probability for the given word pair from the cached bigram maps. + // Also caches the bigrams if there is space remaining and they have not been cached already. 
+ int getBigramProbability(const uint8_t *const dicRoot, const int wordPosition, + const int nextWordPosition, const int unigramProbability) { + hash_map_compat::const_iterator mapPosition = + mBigramMaps.find(wordPosition); + if (mapPosition != mBigramMaps.end()) { + return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability); + } + if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { + addBigramsForWordPosition(dicRoot, wordPosition); + return mBigramMaps[wordPosition].getBigramProbability( + nextWordPosition, unigramProbability); + } + return BinaryFormat::getBigramProbability( + dicRoot, wordPosition, nextWordPosition, unigramProbability); + } + + void clear() { + mBigramMaps.clear(); + } + + private: + DISALLOW_COPY_AND_ASSIGN(MultiBigramMap); + + class BigramMap { + public: + BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP) {} + ~BigramMap() {} + + void init(const uint8_t *const dicRoot, int position) { + BinaryFormat::fillBigramProbabilityToHashMap(dicRoot, position, &mBigramMap); + } + + inline int getBigramProbability(const int nextWordPosition, const int unigramProbability) + const { + return BinaryFormat::getBigramProbabilityFromHashMap( + nextWordPosition, &mBigramMap, unigramProbability); + } + + private: + // Note: Default copy constructor needed for use in hash_map. 
+ hash_map_compat mBigramMap; + }; + + void addBigramsForWordPosition(const uint8_t *const dicRoot, const int position) { + mBigramMaps[position].init(dicRoot, position); + } + + hash_map_compat mBigramMaps; +}; +} // namespace latinime +#endif // LATINIME_MULTI_BIGRAM_MAP_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index a04812279..5357c3773 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -21,6 +21,7 @@ #include "dic_node.h" #include "dic_node_utils.h" #include "dic_node_vector.h" +#include "multi_bigram_map.h" #include "proximity_info.h" #include "proximity_info_state.h" @@ -191,11 +192,11 @@ namespace latinime { * Computes the combined bigram / unigram cost for the given dicNode. */ /* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat *bigramCacheMap) { + const DicNode *const node, MultiBigramMap *multiBigramMap) { if (node->isImpossibleBigramWord()) { return static_cast(MAX_VALUE_FOR_WEIGHTING); } - const int probability = getBigramNodeProbability(dicRoot, node, bigramCacheMap); + const int probability = getBigramNodeProbability(dicRoot, node, multiBigramMap); // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. 
const float cost = static_cast(MAX_PROBABILITY - probability) / static_cast(MAX_PROBABILITY); @@ -203,83 +204,25 @@ namespace latinime { } /* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat *bigramCacheMap) { + const DicNode *const node, MultiBigramMap *multiBigramMap) { const int unigramProbability = node->getProbability(); - const int encodedDiffOfBigramProbability = - getBigramNodeEncodedDiffProbability(dicRoot, node, bigramCacheMap); - if (NOT_A_PROBABILITY == encodedDiffOfBigramProbability) { + const int wordPos = node->getPos(); + const int prevWordPos = node->getPrevWordPos(); + if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) { + // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD. return backoff(unigramProbability); } - return BinaryFormat::computeProbabilityForBigram( - unigramProbability, encodedDiffOfBigramProbability); + if (multiBigramMap) { + return multiBigramMap->getBigramProbability( + dicRoot, prevWordPos, wordPos, unigramProbability); + } + return BinaryFormat::getBigramProbability(dicRoot, prevWordPos, wordPos, unigramProbability); } /////////////////////////////////////// // Bigram / Unigram dictionary utils // /////////////////////////////////////// -/* static */ int16_t DicNodeUtils::getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat *bigramCacheMap) { - const int wordPos = node->getPos(); - const int prevWordPos = node->getPrevWordPos(); - return getBigramProbability(dicRoot, prevWordPos, wordPos, bigramCacheMap); -} - -// TODO: Move this to BigramDictionary -/* static */ int16_t DicNodeUtils::getBigramProbability(const uint8_t *const dicRoot, int pos, - const int nextPos, hash_map_compat *bigramCacheMap) { - // TODO: this is painfully slow compared to the method used in the previous version of the - // algorithm. Switch to that method. 
- if (NOT_VALID_WORD == pos) return NOT_A_PROBABILITY; - if (NOT_VALID_WORD == nextPos) return NOT_A_PROBABILITY; - - // Create a hash code for the given node pair (based on Josh Bloch's effective Java). - // TODO: Use a real hash map data structure that deals with collisions. - int hash = 17; - hash = hash * 31 + pos; - hash = hash * 31 + nextPos; - - hash_map_compat::const_iterator mapPos = bigramCacheMap->find(hash); - if (mapPos != bigramCacheMap->end()) { - return mapPos->second; - } - if (NOT_VALID_WORD == pos) { - return NOT_A_PROBABILITY; - } - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos); - if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) { - return NOT_A_PROBABILITY; - } - if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) { - BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos); - } else { - pos = BinaryFormat::skipOtherCharacters(dicRoot, pos); - } - pos = BinaryFormat::skipChildrenPosition(flags, pos); - pos = BinaryFormat::skipProbability(flags, pos); - uint8_t bigramFlags; - int count = 0; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos); - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(dicRoot, - bigramFlags, &pos); - if (bigramPos == nextPos) { - const int16_t probability = BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - if (static_cast(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) { - (*bigramCacheMap)[hash] = probability; - } - return probability; - } - count++; - } while ((BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags) - && count < MAX_BIGRAMS_CONSIDERED_PER_CONTEXT); - if (static_cast(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) { - // TODO: does this -1 mean NOT_VALID_WORD? 
- (*bigramCacheMap)[hash] = -1; - } - return NOT_A_PROBABILITY; -} - /* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, const int nodeCodePoint) { if (!pInfoState) { diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h index 2e6361d87..5bc542d05 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -21,7 +21,6 @@ #include #include "defines.h" -#include "hash_map_compat.h" namespace latinime { @@ -29,6 +28,7 @@ class DicNode; class DicNodeVector; class ProximityInfo; class ProximityInfoState; +class MultiBigramMap; class DicNodeUtils { public: @@ -42,7 +42,7 @@ class DicNodeUtils { static void getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot, DicNodeVector *childDicNodes); static float getBigramNodeImprobability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat *const bigramCacheMap); + const DicNode *const node, MultiBigramMap *const multiBigramMap); static bool isDicNodeFilteredOut(const int nodeCodePoint, const ProximityInfo *const pInfo, const std::vector *const codePointsFilter); // TODO: Move to private @@ -57,15 +57,11 @@ class DicNodeUtils { private: DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils); - // Max cache size for the space omission error correction bigram lookup - static const int MAX_BIGRAM_MAP_SIZE = 20000; // Max number of bigrams to look up static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500; static int getBigramNodeProbability(const uint8_t *const dicRoot, const DicNode *const node, - hash_map_compat *bigramCacheMap); - static int16_t getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat *bigramCacheMap); + MultiBigramMap *multiBigramMap); static void createAndGetPassingChildNode(DicNode *dicNode, const 
ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, DicNodeVector *childDicNodes); static void createAndGetAllLeavingChildNodes(DicNode *dicNode, const uint8_t *const dicRoot, @@ -76,8 +72,6 @@ class DicNodeUtils { const int terminalDepth, const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, const std::vector *const codePointsFilter, const ProximityInfo *const pInfo, DicNodeVector *childDicNodes); - static int16_t getBigramProbability(const uint8_t *const dicRoot, int pos, const int nextPos, - hash_map_compat *bigramCacheMap); // TODO: Move to proximity info static bool isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex, diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index 4912b22f2..d01531f07 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -18,7 +18,6 @@ #include "char_utils.h" #include "defines.h" -#include "hash_map_compat.h" #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_profiler.h" #include "suggest/core/dicnode/dic_node_utils.h" @@ -26,6 +25,8 @@ namespace latinime { +class MultiBigramMap; + static inline void profile(const CorrectionType correctionType, DicNode *const node) { #if DEBUG_DICT switch (correctionType) { @@ -71,14 +72,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n /* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting, const CorrectionType correctionType, const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, DicNode *const dicNode, - hash_map_compat *const bigramCacheMap) { + MultiBigramMap *const multiBigramMap) { const int inputSize = traverseSession->getInputSize(); DicNode_InputStateG inputStateG; inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default const 
float spatialCost = Weighting::getSpatialCost(weighting, correctionType, traverseSession, parentDicNode, dicNode, &inputStateG); const float languageCost = Weighting::getLanguageCost(weighting, correctionType, - traverseSession, parentDicNode, dicNode, bigramCacheMap); + traverseSession, parentDicNode, dicNode, multiBigramMap); const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession, parentDicNode, dicNode); profile(correctionType, dicNode); @@ -127,14 +128,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n /* static */ float Weighting::getLanguageCost(const Weighting *const weighting, const CorrectionType correctionType, const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode, - hash_map_compat *const bigramCacheMap) { + MultiBigramMap *const multiBigramMap) { switch(correctionType) { case CT_OMISSION: return 0.0f; case CT_SUBSTITUTION: return 0.0f; case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); + return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); case CT_MATCH: return 0.0f; case CT_COMPLETION: @@ -142,11 +143,11 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_TERMINAL: { const float languageImprobability = DicNodeUtils::getBigramNodeImprobability( - traverseSession->getOffsetDict(), dicNode, bigramCacheMap); + traverseSession->getOffsetDict(), dicNode, multiBigramMap); return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); } case CT_NEW_WORD_SPACE_SUBSTITUTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); + return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); case CT_INSERTION: return 0.0f; case CT_TRANSPOSITION: diff --git a/native/jni/src/suggest/core/policy/weighting.h 
b/native/jni/src/suggest/core/policy/weighting.h index 6e740d9d6..0d2745b40 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -18,13 +18,13 @@ #define LATINIME_WEIGHTING_H #include "defines.h" -#include "hash_map_compat.h" namespace latinime { class DicNode; class DicTraverseSession; struct DicNode_InputStateG; +class MultiBigramMap; class Weighting { public: @@ -32,7 +32,7 @@ class Weighting { const CorrectionType correctionType, const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, DicNode *const dicNode, - hash_map_compat *const bigramCacheMap); + MultiBigramMap *const multiBigramMap); protected: virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, @@ -61,7 +61,7 @@ class Weighting { virtual float getNewWordBigramCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, - hash_map_compat *const bigramCacheMap) const = 0; + MultiBigramMap *const multiBigramMap) const = 0; virtual float getCompletionCost( const DicTraverseSession *const traverseSession, @@ -97,7 +97,7 @@ class Weighting { static float getLanguageCost(const Weighting *const weighting, const CorrectionType correctionType, const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode, - hash_map_compat *const bigramCacheMap); + MultiBigramMap *const multiBigramMap); // TODO: Move to TypingWeighting and GestureWeighting? 
static int getForwardInputCount(const CorrectionType correctionType); }; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index b3d47326d..51165858b 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -100,7 +100,7 @@ int DicTraverseSession::getDictFlags() const { void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { mDicNodesCache.reset(nextActiveCacheSize, maxWords); - mBigramCacheMap.clear(); + mMultiBigramMap.clear(); mPartiallyCommited = false; } diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index d9c2a51d0..d88be5b88 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -21,8 +21,8 @@ #include #include "defines.h" -#include "hash_map_compat.h" #include "jni.h" +#include "multi_bigram_map.h" #include "proximity_info_state.h" #include "suggest/core/dicnode/dic_nodes_cache.h" @@ -35,7 +35,7 @@ class DicTraverseSession { public: AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), - mDictionary(0), mDicNodesCache(), mBigramCacheMap(), + mDictionary(0), mDicNodesCache(), mMultiBigramMap(), mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. 
@@ -67,7 +67,7 @@ class DicTraverseSession { // TODO: Use proper parameter when changed int getDicRootPos() const { return 0; } DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } - hash_map_compat *getBigramCacheMap() { return &mBigramCacheMap; } + MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; } const ProximityInfoState *getProximityInfoState(int id) const { return &mProximityInfoStates[id]; } @@ -170,7 +170,7 @@ class DicTraverseSession { DicNodesCache mDicNodesCache; // Temporary cache for bigram frequencies - hash_map_compat mBigramCacheMap; + MultiBigramMap mMultiBigramMap; ProximityInfoState mProximityInfoStates[MAX_POINTER_COUNT_G]; int mInputSize; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 4f94a9a3b..3221dee9c 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -359,7 +359,7 @@ void Suggest::processTerminalDicNode( DicNode terminalDicNode; DicNodeUtils::initByCopy(dicNode, &terminalDicNode); Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0, - &terminalDicNode, traverseSession->getBigramCacheMap()); + &terminalDicNode, traverseSession->getMultiBigramMap()); traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode); } @@ -391,8 +391,10 @@ void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession, void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession, DicNode *dicNode, DicNode *childDicNode) const { + // Note: Most types of corrections don't need to look up the bigram information since they do + // not treat the node as a terminal. There is no need to pass the bigram map in these cases. 
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY, - traverseSession, dicNode, childDicNode, 0 /* bigramCacheMap */); + traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */); weightChildNode(traverseSession, childDicNode); processExpandedDicNode(traverseSession, childDicNode); } @@ -400,7 +402,7 @@ void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traver void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession, DicNode *dicNode, DicNode *childDicNode) const { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession, - dicNode, childDicNode, 0 /* bigramCacheMap */); + dicNode, childDicNode, 0 /* multiBigramMap */); weightChildNode(traverseSession, childDicNode); processExpandedDicNode(traverseSession, childDicNode); } @@ -432,7 +434,7 @@ void Suggest::processDicNodeAsOmission( DicNode *const childDicNode = childDicNodes[i]; // Treat this word as omission Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, - dicNode, childDicNode, 0 /* bigramCacheMap */); + dicNode, childDicNode, 0 /* multiBigramMap */); weightChildNode(traverseSession, childDicNode); if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) { @@ -456,7 +458,7 @@ void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession, for (int i = 0; i < size; i++) { DicNode *const childDicNode = childDicNodes[i]; Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession, - dicNode, childDicNode, 0 /* bigramCacheMap */); + dicNode, childDicNode, 0 /* multiBigramMap */); processExpandedDicNode(traverseSession, childDicNode); } } @@ -481,7 +483,7 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession, for (int j = 0; j < childSize2; j++) { DicNode *const childDicNode2 = childDicNodes2[j]; Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION, - traverseSession, 
childDicNodes1[i], childDicNode2, 0 /* bigramCacheMap */); + traverseSession, childDicNodes1[i], childDicNode2, 0 /* multiBigramMap */); processExpandedDicNode(traverseSession, childDicNode2); } } @@ -496,10 +498,10 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN const int inputSize = traverseSession->getInputSize(); if (dicNode->isCompletion(inputSize)) { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession, - 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */); + 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); } else { // completion Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession, - 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */); + 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); } } @@ -520,7 +522,7 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode const CorrectionType correctionType = spaceSubstitution ? CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION; Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode, - &newDicNode, traverseSession->getBigramCacheMap()); + &newDicNode, traverseSession->getMultiBigramMap()); traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 9efcc17fe..e6fa1bdc4 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -28,6 +28,7 @@ namespace latinime { class DicNode; struct DicNode_InputStateG; +class MultiBigramMap; class TypingWeighting : public Weighting { public: @@ -136,9 +137,9 @@ class TypingWeighting : public Weighting { float getNewWordBigramCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, - hash_map_compat *const 
bigramCacheMap) const { + MultiBigramMap *const multiBigramMap) const { return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(), - dicNode, bigramCacheMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; + dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; } float getCompletionCost(const DicTraverseSession *const traverseSession,