Merge "Use WordIdArrayView for prevWordIds."
This commit is contained in:
commit
1605630cf9
20 changed files with 133 additions and 138 deletions
|
@ -105,7 +105,7 @@ class DicNode {
|
|||
}
|
||||
|
||||
// Init for root with prevWordIds which is used for n-gram
|
||||
void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordIds) {
|
||||
void initAsRoot(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) {
|
||||
mIsCachedForNextSuggestion = false;
|
||||
mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds);
|
||||
mDicNodeState.init();
|
||||
|
@ -115,12 +115,11 @@ class DicNode {
|
|||
// Init for root with previous word
|
||||
void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
|
||||
mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
|
||||
int newPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> newPrevWordIds;
|
||||
newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId();
|
||||
for (size_t i = 1; i < NELEMS(newPrevWordIds); ++i) {
|
||||
newPrevWordIds[i] = dicNode->getPrevWordIds()[i - 1];
|
||||
}
|
||||
mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordIds);
|
||||
dicNode->getPrevWordIds().limit(newPrevWordIds.size() - 1)
|
||||
.copyToArray(&newPrevWordIds, 1 /* offset */);
|
||||
mDicNodeProperties.init(rootPtNodeArrayPos, WordIdArrayView::fromArray(newPrevWordIds));
|
||||
mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
|
||||
dicNode->mDicNodeProperties.getDepth());
|
||||
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
|
||||
|
@ -203,8 +202,7 @@ class DicNode {
|
|||
return mDicNodeProperties.getWordId();
|
||||
}
|
||||
|
||||
// TODO: Use view class to return word id array.
|
||||
const int *getPrevWordIds() const {
|
||||
const WordIdArrayView getPrevWordIds() const {
|
||||
return mDicNodeProperties.getPrevWordIds();
|
||||
}
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace latinime {
|
|||
|
||||
/* static */ void DicNodeUtils::initAsRoot(
|
||||
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
|
||||
const int *const prevWordIds, DicNode *const newRootDicNode) {
|
||||
const WordIdArrayView prevWordIds, DicNode *const newRootDicNode) {
|
||||
newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds);
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define LATINIME_DIC_NODE_UTILS_H
|
||||
|
||||
#include "defines.h"
|
||||
#include "utils/int_array_view.h"
|
||||
|
||||
namespace latinime {
|
||||
|
||||
|
@ -30,7 +31,7 @@ class DicNodeUtils {
|
|||
public:
|
||||
static void initAsRoot(
|
||||
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
|
||||
const int *const prevWordIds, DicNode *const newRootDicNode);
|
||||
const WordIdArrayView prevWordIds, DicNode *const newRootDicNode);
|
||||
static void initAsRootWithPreviousWord(
|
||||
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
|
||||
const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode);
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
#define LATINIME_DIC_NODE_PROPERTIES_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "defines.h"
|
||||
#include "utils/int_array_view.h"
|
||||
|
||||
namespace latinime {
|
||||
|
||||
|
@ -36,23 +38,23 @@ class DicNodeProperties {
|
|||
|
||||
// Should be called only once per DicNode is initialized.
|
||||
void init(const int childrenPos, const int nodeCodePoint, const int wordId,
|
||||
const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordIds) {
|
||||
const uint16_t depth, const uint16_t leavingDepth, const WordIdArrayView prevWordIds) {
|
||||
mChildrenPtNodeArrayPos = childrenPos;
|
||||
mDicNodeCodePoint = nodeCodePoint;
|
||||
mWordId = wordId;
|
||||
mDepth = depth;
|
||||
mLeavingDepth = leavingDepth;
|
||||
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
|
||||
prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */);
|
||||
}
|
||||
|
||||
// Init for root with prevWordsPtNodePos which is used for n-gram
|
||||
void init(const int rootPtNodeArrayPos, const int *const prevWordIds) {
|
||||
void init(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) {
|
||||
mChildrenPtNodeArrayPos = rootPtNodeArrayPos;
|
||||
mDicNodeCodePoint = NOT_A_CODE_POINT;
|
||||
mWordId = NOT_A_WORD_ID;
|
||||
mDepth = 0;
|
||||
mLeavingDepth = 0;
|
||||
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
|
||||
prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */);
|
||||
}
|
||||
|
||||
void initByCopy(const DicNodeProperties *const dicNodeProp) {
|
||||
|
@ -61,7 +63,8 @@ class DicNodeProperties {
|
|||
mWordId = dicNodeProp->mWordId;
|
||||
mDepth = dicNodeProp->mDepth;
|
||||
mLeavingDepth = dicNodeProp->mLeavingDepth;
|
||||
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
|
||||
WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds)
|
||||
.copyToArray(&mPrevWordIds, 0 /* offset */);
|
||||
}
|
||||
|
||||
// Init as passing child
|
||||
|
@ -71,7 +74,8 @@ class DicNodeProperties {
|
|||
mWordId = dicNodeProp->mWordId;
|
||||
mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child
|
||||
mLeavingDepth = dicNodeProp->mLeavingDepth;
|
||||
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
|
||||
WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds)
|
||||
.copyToArray(&mPrevWordIds, 0 /* offset */);
|
||||
}
|
||||
|
||||
int getChildrenPtNodeArrayPos() const {
|
||||
|
@ -99,8 +103,8 @@ class DicNodeProperties {
|
|||
return (mChildrenPtNodeArrayPos != NOT_A_DICT_POS) || mDepth != mLeavingDepth;
|
||||
}
|
||||
|
||||
const int *getPrevWordIds() const {
|
||||
return mPrevWordIds;
|
||||
const WordIdArrayView getPrevWordIds() const {
|
||||
return WordIdArrayView::fromArray(mPrevWordIds);
|
||||
}
|
||||
|
||||
int getWordId() const {
|
||||
|
@ -116,7 +120,7 @@ class DicNodeProperties {
|
|||
int mWordId;
|
||||
uint16_t mDepth;
|
||||
uint16_t mLeavingDepth;
|
||||
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIds;
|
||||
};
|
||||
} // namespace latinime
|
||||
#endif // LATINIME_DIC_NODE_PROPERTIES_H
|
||||
|
|
|
@ -85,7 +85,7 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
|
|||
return;
|
||||
}
|
||||
const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
|
||||
mPrevWordIds.data(), targetWordId, nullptr /* multiBigramMap */);
|
||||
mPrevWordIds, targetWordId, nullptr /* multiBigramMap */);
|
||||
mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
|
||||
wordAttributes.getProbability());
|
||||
}
|
||||
|
@ -93,13 +93,13 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
|
|||
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
|
||||
SuggestionResults *const outSuggestionResults) const {
|
||||
TimeKeeper::setCurrentTime();
|
||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
|
||||
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(),
|
||||
true /* tryLowerCaseSearch */);
|
||||
NgramListenerForPrediction listener(prevWordsInfo,
|
||||
WordIdArrayView::fromFixedSizeArray(prevWordIds), outSuggestionResults,
|
||||
const WordIdArrayView prevWordIdArrayView = WordIdArrayView::fromArray(prevWordIds);
|
||||
NgramListenerForPrediction listener(prevWordsInfo, prevWordIdArrayView, outSuggestionResults,
|
||||
mDictionaryStructureWithBufferPolicy.get());
|
||||
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener);
|
||||
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIdArrayView, &listener);
|
||||
}
|
||||
|
||||
int Dictionary::getProbability(const int *word, int length) const {
|
||||
|
@ -119,13 +119,13 @@ int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, co
|
|||
CodePointArrayView(word, length), false /* forceLowerCaseSearch */);
|
||||
if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY;
|
||||
if (!prevWordsInfo) {
|
||||
return getDictionaryStructurePolicy()->getProbabilityOfWord(
|
||||
nullptr /* prevWordsPtNodePos */, wordId);
|
||||
return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId);
|
||||
}
|
||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
|
||||
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(),
|
||||
true /* tryLowerCaseSearch */);
|
||||
return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId);
|
||||
return getDictionaryStructurePolicy()->getProbabilityOfWord(
|
||||
IntArrayView::fromArray(prevWordIds), wordId);
|
||||
}
|
||||
|
||||
bool Dictionary::addUnigramEntry(const int *const word, const int length,
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "suggest/core/dictionary/digraph_utils.h"
|
||||
#include "suggest/core/session/prev_words_info.h"
|
||||
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
|
||||
#include "utils/int_array_view.h"
|
||||
|
||||
namespace latinime {
|
||||
|
||||
|
@ -34,11 +35,12 @@ namespace latinime {
|
|||
|
||||
// No prev words information.
|
||||
PrevWordsInfo emptyPrevWordsInfo;
|
||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds,
|
||||
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||
emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds.data(),
|
||||
false /* tryLowerCaseSearch */);
|
||||
current.emplace_back();
|
||||
DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, ¤t.front());
|
||||
DicNodeUtils::initAsRoot(dictionaryStructurePolicy,
|
||||
IntArrayView::fromArray(prevWordIds), ¤t.front());
|
||||
for (int i = 0; i < codePointCount; ++i) {
|
||||
// The base-lower input is used to ignore case errors and accent errors.
|
||||
const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]);
|
||||
|
|
|
@ -35,9 +35,9 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP =
|
|||
// Also caches the bigrams if there is space remaining and they have not been cached already.
|
||||
int MultiBigramMap::getBigramProbability(
|
||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||
const int *const prevWordIds, const int nextWordId,
|
||||
const WordIdArrayView prevWordIds, const int nextWordId,
|
||||
const int unigramProbability) {
|
||||
if (!prevWordIds || prevWordIds[0] == NOT_A_WORD_ID) {
|
||||
if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) {
|
||||
return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY);
|
||||
}
|
||||
const auto mapPosition = mBigramMaps.find(prevWordIds[0]);
|
||||
|
@ -56,7 +56,7 @@ int MultiBigramMap::getBigramProbability(
|
|||
|
||||
void MultiBigramMap::BigramMap::init(
|
||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||
const int *const prevWordIds) {
|
||||
const WordIdArrayView prevWordIds) {
|
||||
structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */);
|
||||
}
|
||||
|
||||
|
@ -83,16 +83,13 @@ void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, const i
|
|||
|
||||
void MultiBigramMap::addBigramsForWord(
|
||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||
const int *const prevWordIds) {
|
||||
if (prevWordIds) {
|
||||
const WordIdArrayView prevWordIds) {
|
||||
mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds);
|
||||
}
|
||||
}
|
||||
|
||||
int MultiBigramMap::readBigramProbabilityFromBinaryDictionary(
|
||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||
const int *const prevWordIds, const int nextWordId,
|
||||
const int unigramProbability) {
|
||||
const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability) {
|
||||
const int bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId);
|
||||
if (bigramProbability != NOT_A_PROBABILITY) {
|
||||
return bigramProbability;
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include "suggest/core/dictionary/bloom_filter.h"
|
||||
#include "suggest/core/dictionary/ngram_listener.h"
|
||||
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
|
||||
#include "utils/int_array_view.h"
|
||||
|
||||
namespace latinime {
|
||||
|
||||
|
@ -39,7 +40,7 @@ class MultiBigramMap {
|
|||
// Look up the bigram probability for the given word pair from the cached bigram maps.
|
||||
// Also caches the bigrams if there is space remaining and they have not been cached already.
|
||||
int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||
const int *const prevWordIds, const int nextWordId, const int unigramProbability);
|
||||
const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability);
|
||||
|
||||
void clear() {
|
||||
mBigramMaps.clear();
|
||||
|
@ -57,7 +58,7 @@ class MultiBigramMap {
|
|||
virtual ~BigramMap() {}
|
||||
|
||||
void init(const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||
const int *const prevWordIds);
|
||||
const WordIdArrayView prevWordIds);
|
||||
int getBigramProbability(
|
||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||
const int nextWordId, const int unigramProbability) const;
|
||||
|
@ -70,11 +71,11 @@ class MultiBigramMap {
|
|||
};
|
||||
|
||||
void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||
const int *const prevWordIds);
|
||||
const WordIdArrayView prevWordIds);
|
||||
|
||||
int readBigramProbabilityFromBinaryDictionary(
|
||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||
const int *const prevWordIds, const int nextWordId, const int unigramProbability);
|
||||
const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability);
|
||||
|
||||
static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP;
|
||||
std::unordered_map<int, BigramMap> mBigramMaps;
|
||||
|
|
|
@ -58,15 +58,15 @@ class DictionaryStructureWithBufferPolicy {
|
|||
virtual int getWordId(const CodePointArrayView wordCodePoints,
|
||||
const bool forceLowerCaseSearch) const = 0;
|
||||
|
||||
virtual const WordAttributes getWordAttributesInContext(const int *const prevWordIds,
|
||||
virtual const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
|
||||
const int wordId, MultiBigramMap *const multiBigramMap) const = 0;
|
||||
|
||||
// TODO: Remove
|
||||
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
|
||||
|
||||
virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0;
|
||||
virtual int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const = 0;
|
||||
|
||||
virtual void iterateNgramEntries(const int *const prevWordIds,
|
||||
virtual void iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||
NgramListener *const listener) const = 0;
|
||||
|
||||
virtual BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const = 0;
|
||||
|
|
|
@ -35,7 +35,7 @@ void DicTraverseSession::init(const Dictionary *const dictionary,
|
|||
mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy()
|
||||
->getMultiWordCostMultiplier();
|
||||
mSuggestOptions = suggestOptions;
|
||||
prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds,
|
||||
prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds.data(),
|
||||
true /* tryLowerCaseSearch */);
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "suggest/core/dicnode/dic_nodes_cache.h"
|
||||
#include "suggest/core/dictionary/multi_bigram_map.h"
|
||||
#include "suggest/core/layout/proximity_info_state.h"
|
||||
#include "utils/int_array_view.h"
|
||||
|
||||
namespace latinime {
|
||||
|
||||
|
@ -55,9 +56,7 @@ class DicTraverseSession {
|
|||
mMultiWordCostMultiplier(1.0f) {
|
||||
// NOTE: mProximityInfoStates is an array of instances.
|
||||
// No need to initialize it explicitly here.
|
||||
for (size_t i = 0; i < NELEMS(mPrevWordsIds); ++i) {
|
||||
mPrevWordsIds[i] = NOT_A_DICT_POS;
|
||||
}
|
||||
mPrevWordsIds.fill(NOT_A_DICT_POS);
|
||||
}
|
||||
|
||||
// Non virtual inline destructor -- never inherit this class
|
||||
|
@ -79,7 +78,7 @@ class DicTraverseSession {
|
|||
//--------------------
|
||||
const ProximityInfo *getProximityInfo() const { return mProximityInfo; }
|
||||
const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; }
|
||||
const int *getPrevWordIds() const { return mPrevWordsIds; }
|
||||
const WordIdArrayView getPrevWordIds() const { return IntArrayView::fromArray(mPrevWordsIds); }
|
||||
DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
|
||||
MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; }
|
||||
const ProximityInfoState *getProximityInfoState(int id) const {
|
||||
|
@ -166,7 +165,7 @@ class DicTraverseSession {
|
|||
const int *const inputYs, const int *const times, const int *const pointerIds,
|
||||
const int inputSize, const float maxSpatialDistance, const int maxPointerCount);
|
||||
|
||||
int mPrevWordsIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordsIds;
|
||||
const ProximityInfo *mProximityInfo;
|
||||
const Dictionary *mDictionary;
|
||||
const SuggestOptions *mSuggestOptions;
|
||||
|
|
|
@ -116,7 +116,7 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
|||
}
|
||||
|
||||
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
||||
const int *const prevWordIds, const int wordId,
|
||||
const WordIdArrayView prevWordIds, const int wordId,
|
||||
MultiBigramMap *const multiBigramMap) const {
|
||||
if (wordId == NOT_A_WORD_ID) {
|
||||
return WordAttributes();
|
||||
|
@ -128,7 +128,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
|||
prevWordIds, wordId, ptNodeParams.getProbability());
|
||||
return getWordAttributes(probability, ptNodeParams);
|
||||
}
|
||||
if (prevWordIds) {
|
||||
if (!prevWordIds.empty()) {
|
||||
const int probability = getProbabilityOfWord(prevWordIds, wordId);
|
||||
if (probability != NOT_A_PROBABILITY) {
|
||||
return getWordAttributes(probability, ptNodeParams);
|
||||
|
@ -160,7 +160,7 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
|||
}
|
||||
}
|
||||
|
||||
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
||||
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
|
||||
const int wordId) const {
|
||||
if (wordId == NOT_A_WORD_ID) {
|
||||
return NOT_A_PROBABILITY;
|
||||
|
@ -170,7 +170,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
|||
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
|
||||
return NOT_A_PROBABILITY;
|
||||
}
|
||||
if (prevWordIds) {
|
||||
if (!prevWordIds.empty()) {
|
||||
const int bigramsPosition = getBigramsPositionOfPtNode(
|
||||
getTerminalPtNodePosFromWordId(prevWordIds[0]));
|
||||
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
|
||||
|
@ -186,9 +186,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
|||
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
||||
}
|
||||
|
||||
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
|
||||
void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||
NgramListener *const listener) const {
|
||||
if (!prevWordIds) {
|
||||
if (prevWordIds.empty()) {
|
||||
return;
|
||||
}
|
||||
const int bigramsPosition = getBigramsPositionOfPtNode(
|
||||
|
|
|
@ -91,14 +91,15 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
|||
|
||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||
|
||||
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
|
||||
MultiBigramMap *const multiBigramMap) const;
|
||||
const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
|
||||
const int wordId, MultiBigramMap *const multiBigramMap) const;
|
||||
|
||||
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
||||
|
||||
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
|
||||
int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const;
|
||||
|
||||
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
|
||||
void iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||
NgramListener *const listener) const;
|
||||
|
||||
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;
|
||||
|
||||
|
|
|
@ -282,8 +282,9 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
|||
return getWordIdFromTerminalPtNodePos(ptNodePos);
|
||||
}
|
||||
|
||||
const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *const prevWordIds,
|
||||
const int wordId, MultiBigramMap *const multiBigramMap) const {
|
||||
const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(
|
||||
const WordIdArrayView prevWordIds, const int wordId,
|
||||
MultiBigramMap *const multiBigramMap) const {
|
||||
if (wordId == NOT_A_WORD_ID) {
|
||||
return WordAttributes();
|
||||
}
|
||||
|
@ -295,7 +296,7 @@ const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *c
|
|||
prevWordIds, wordId, ptNodeParams.getProbability());
|
||||
return getWordAttributes(probability, ptNodeParams);
|
||||
}
|
||||
if (prevWordIds) {
|
||||
if (!prevWordIds.empty()) {
|
||||
const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId);
|
||||
if (bigramProbability != NOT_A_PROBABILITY) {
|
||||
return getWordAttributes(bigramProbability, ptNodeParams);
|
||||
|
@ -327,7 +328,8 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability,
|
|||
}
|
||||
}
|
||||
|
||||
int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const int wordId) const {
|
||||
int PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
|
||||
const int wordId) const {
|
||||
if (wordId == NOT_A_WORD_ID) {
|
||||
return NOT_A_PROBABILITY;
|
||||
}
|
||||
|
@ -340,7 +342,7 @@ int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const
|
|||
// for shortcuts).
|
||||
return NOT_A_PROBABILITY;
|
||||
}
|
||||
if (prevWordIds) {
|
||||
if (!prevWordIds.empty()) {
|
||||
const int bigramsPosition = getBigramsPositionOfPtNode(
|
||||
getTerminalPtNodePosFromWordId(prevWordIds[0]));
|
||||
BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition);
|
||||
|
@ -356,9 +358,9 @@ int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const
|
|||
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
||||
}
|
||||
|
||||
void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
|
||||
void PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||
NgramListener *const listener) const {
|
||||
if (!prevWordIds) {
|
||||
if (prevWordIds.empty()) {
|
||||
return;
|
||||
}
|
||||
const int bigramsPosition = getBigramsPositionOfPtNode(
|
||||
|
@ -371,8 +373,7 @@ void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
|
|||
}
|
||||
}
|
||||
|
||||
BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator(
|
||||
const int wordId) const {
|
||||
BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator(const int wordId) const {
|
||||
const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId));
|
||||
return BinaryDictionaryShortcutIterator(&mShortcutListPolicy, shortcutPos);
|
||||
}
|
||||
|
|
|
@ -66,14 +66,15 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
|||
|
||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||
|
||||
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
|
||||
MultiBigramMap *const multiBigramMap) const;
|
||||
const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
|
||||
const int wordId, MultiBigramMap *const multiBigramMap) const;
|
||||
|
||||
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
||||
|
||||
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
|
||||
int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const;
|
||||
|
||||
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
|
||||
void iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||
NgramListener *const listener) const;
|
||||
|
||||
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h"
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "suggest/core/dicnode/dic_node.h"
|
||||
|
@ -111,7 +112,7 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
|||
}
|
||||
|
||||
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
||||
const int *const prevWordIds, const int wordId,
|
||||
const WordIdArrayView prevWordIds, const int wordId,
|
||||
MultiBigramMap *const multiBigramMap) const {
|
||||
if (wordId == NOT_A_WORD_ID) {
|
||||
return WordAttributes();
|
||||
|
@ -121,27 +122,11 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
|||
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
||||
// TODO: Support n-gram.
|
||||
return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
|
||||
WordIdArrayView::singleElementView(prevWordIds), wordId), ptNodeParams.isBlacklisted(),
|
||||
prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(),
|
||||
ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
|
||||
}
|
||||
|
||||
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
||||
const int bigramProbability) const {
|
||||
if (mHeaderPolicy->isDecayingDict()) {
|
||||
// Both probabilities are encoded. Decode them and get probability.
|
||||
return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability);
|
||||
} else {
|
||||
if (unigramProbability == NOT_A_PROBABILITY) {
|
||||
return NOT_A_PROBABILITY;
|
||||
} else if (bigramProbability == NOT_A_PROBABILITY) {
|
||||
return ProbabilityUtils::backoff(unigramProbability);
|
||||
} else {
|
||||
return bigramProbability;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
||||
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
|
||||
const int wordId) const {
|
||||
if (wordId == NOT_A_WORD_ID) {
|
||||
return NOT_A_PROBABILITY;
|
||||
|
@ -152,11 +137,10 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
|||
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
|
||||
return NOT_A_PROBABILITY;
|
||||
}
|
||||
if (prevWordIds) {
|
||||
// TODO: Support n-gram.
|
||||
const ProbabilityEntry probabilityEntry =
|
||||
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
|
||||
IntArrayView::singleElementView(prevWordIds), wordId);
|
||||
prevWordIds.limit(1 /* maxSize */), wordId);
|
||||
if (!probabilityEntry.isValid()) {
|
||||
return NOT_A_PROBABILITY;
|
||||
}
|
||||
|
@ -167,8 +151,6 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
|||
return probabilityEntry.getProbability();
|
||||
}
|
||||
}
|
||||
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
||||
}
|
||||
|
||||
BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
|
||||
const int wordId) const {
|
||||
|
@ -176,15 +158,15 @@ BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
|
|||
return BinaryDictionaryShortcutIterator(&mShortcutPolicy, shortcutPos);
|
||||
}
|
||||
|
||||
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
|
||||
void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||
NgramListener *const listener) const {
|
||||
if (!prevWordIds) {
|
||||
if (prevWordIds.empty()) {
|
||||
return;
|
||||
}
|
||||
// TODO: Support n-gram.
|
||||
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
|
||||
for (const auto entry : languageModelDictContent->getProbabilityEntries(
|
||||
WordIdArrayView::singleElementView(prevWordIds))) {
|
||||
prevWordIds.limit(1 /* maxSize */))) {
|
||||
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
|
||||
const int probability = probabilityEntry.hasHistoricalInfo() ?
|
||||
ForgettingCurveUtils::decodeProbability(
|
||||
|
@ -321,8 +303,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
|
|||
"length: %zd", bigramProperty->getTargetCodePoints()->size());
|
||||
return false;
|
||||
}
|
||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
|
||||
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||
prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */);
|
||||
// TODO: Support N-gram.
|
||||
if (prevWordIds[0] == NOT_A_WORD_ID) {
|
||||
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
|
||||
|
@ -337,7 +319,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
|
|||
return false;
|
||||
}
|
||||
// Refresh word ids.
|
||||
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
|
||||
prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
@ -348,14 +330,14 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
|
|||
return false;
|
||||
}
|
||||
bool addedNewEntry = false;
|
||||
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
|
||||
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos;
|
||||
for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) {
|
||||
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
|
||||
->getTerminalPtNodePosition(prevWordIds[i]);
|
||||
}
|
||||
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
|
||||
->getTerminalPtNodePosition(wordId);
|
||||
if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
|
||||
if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos),
|
||||
wordPtNodePos, bigramProperty, &addedNewEntry)) {
|
||||
if (addedNewEntry) {
|
||||
mBigramCount++;
|
||||
|
@ -385,8 +367,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
|
|||
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
|
||||
wordCodePoints.size());
|
||||
}
|
||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */);
|
||||
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||
prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSerch */);
|
||||
// TODO: Support N-gram.
|
||||
if (prevWordIds[0] == NOT_A_WORD_ID) {
|
||||
return false;
|
||||
|
@ -395,14 +377,14 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
|
|||
if (wordId == NOT_A_WORD_ID) {
|
||||
return false;
|
||||
}
|
||||
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
|
||||
std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos;
|
||||
for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) {
|
||||
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
|
||||
->getTerminalPtNodePosition(prevWordIds[i]);
|
||||
}
|
||||
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
|
||||
->getTerminalPtNodePosition(wordId);
|
||||
if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
|
||||
if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos),
|
||||
wordPtNodePos)) {
|
||||
mBigramCount--;
|
||||
return true;
|
||||
|
|
|
@ -68,14 +68,19 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
|||
|
||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||
|
||||
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
|
||||
MultiBigramMap *const multiBigramMap) const;
|
||||
const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
|
||||
const int wordId, MultiBigramMap *const multiBigramMap) const;
|
||||
|
||||
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
||||
// TODO: Remove
|
||||
int getProbability(const int unigramProbability, const int bigramProbability) const {
|
||||
// Not used.
|
||||
return NOT_A_PROBABILITY;
|
||||
}
|
||||
|
||||
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
|
||||
int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const;
|
||||
|
||||
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
|
||||
void iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||
NgramListener *const listener) const;
|
||||
|
||||
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;
|
||||
|
||||
|
|
|
@ -57,9 +57,9 @@ class IntArrayView {
|
|||
explicit IntArrayView(const std::vector<int> &vector)
|
||||
: mPtr(vector.data()), mSize(vector.size()) {}
|
||||
|
||||
template <int N>
|
||||
AK_FORCE_INLINE static IntArrayView fromFixedSizeArray(const int (&array)[N]) {
|
||||
return IntArrayView(array, N);
|
||||
template <size_t N>
|
||||
AK_FORCE_INLINE static IntArrayView fromArray(const std::array<int, N> &array) {
|
||||
return IntArrayView(array.data(), array.size());
|
||||
}
|
||||
|
||||
// Returns a view that points one int object.
|
||||
|
@ -120,6 +120,8 @@ class IntArrayView {
|
|||
using WordIdArrayView = IntArrayView;
|
||||
using PtNodePosArrayView = IntArrayView;
|
||||
using CodePointArrayView = IntArrayView;
|
||||
template <size_t size>
|
||||
using WordIdArray = std::array<int, size>;
|
||||
|
||||
} // namespace latinime
|
||||
#endif // LATINIME_MEMORY_VIEW_H
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <array>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "utils/int_array_view.h"
|
||||
|
@ -97,8 +98,8 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
|
|||
const int bigramProbability = 20;
|
||||
const int trigramProbability = 30;
|
||||
const int wordId = 100;
|
||||
const int prevWordIdArray[] = { 1, 2 };
|
||||
const WordIdArrayView prevWordIds = WordIdArrayView::fromFixedSizeArray(prevWordIdArray);
|
||||
const std::array<int, 2> prevWordIdArray = {{ 1, 2 }};
|
||||
const WordIdArrayView prevWordIds = WordIdArrayView::fromArray(prevWordIdArray);
|
||||
|
||||
const ProbabilityEntry probabilityEntry(flag, probability);
|
||||
languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
||||
|
|
|
@ -46,8 +46,8 @@ TEST(IntArrayViewTest, TestIteration) {
|
|||
|
||||
TEST(IntArrayViewTest, TestConstructFromArray) {
|
||||
const size_t ARRAY_SIZE = 100;
|
||||
int intArray[ARRAY_SIZE];
|
||||
const auto intArrayView = IntArrayView::fromFixedSizeArray(intArray);
|
||||
std::array<int, ARRAY_SIZE> intArray;
|
||||
const auto intArrayView = IntArrayView::fromArray(intArray);
|
||||
EXPECT_EQ(ARRAY_SIZE, intArrayView.size());
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue