Use WordIdArrayView for prevWordIds.
Bug: 14425059 Change-Id: Ia84fb997d89564e60111b46ca83bbfa3b187f316main
parent
a3b0eb1685
commit
537f6eea8a
|
@ -105,7 +105,7 @@ class DicNode {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init for root with prevWordIds which is used for n-gram
|
// Init for root with prevWordIds which is used for n-gram
|
||||||
void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordIds) {
|
void initAsRoot(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) {
|
||||||
mIsCachedForNextSuggestion = false;
|
mIsCachedForNextSuggestion = false;
|
||||||
mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds);
|
mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds);
|
||||||
mDicNodeState.init();
|
mDicNodeState.init();
|
||||||
|
@ -115,12 +115,11 @@ class DicNode {
|
||||||
// Init for root with previous word
|
// Init for root with previous word
|
||||||
void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
|
void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
|
||||||
mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
|
mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
|
||||||
int newPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> newPrevWordIds;
|
||||||
newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId();
|
newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId();
|
||||||
for (size_t i = 1; i < NELEMS(newPrevWordIds); ++i) {
|
dicNode->getPrevWordIds().limit(newPrevWordIds.size() - 1)
|
||||||
newPrevWordIds[i] = dicNode->getPrevWordIds()[i - 1];
|
.copyToArray(&newPrevWordIds, 1 /* offset */);
|
||||||
}
|
mDicNodeProperties.init(rootPtNodeArrayPos, WordIdArrayView::fromArray(newPrevWordIds));
|
||||||
mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordIds);
|
|
||||||
mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
|
mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
|
||||||
dicNode->mDicNodeProperties.getDepth());
|
dicNode->mDicNodeProperties.getDepth());
|
||||||
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
|
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
|
||||||
|
@ -203,8 +202,7 @@ class DicNode {
|
||||||
return mDicNodeProperties.getWordId();
|
return mDicNodeProperties.getWordId();
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Use view class to return word id array.
|
const WordIdArrayView getPrevWordIds() const {
|
||||||
const int *getPrevWordIds() const {
|
|
||||||
return mDicNodeProperties.getPrevWordIds();
|
return mDicNodeProperties.getPrevWordIds();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ namespace latinime {
|
||||||
|
|
||||||
/* static */ void DicNodeUtils::initAsRoot(
|
/* static */ void DicNodeUtils::initAsRoot(
|
||||||
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
|
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
|
||||||
const int *const prevWordIds, DicNode *const newRootDicNode) {
|
const WordIdArrayView prevWordIds, DicNode *const newRootDicNode) {
|
||||||
newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds);
|
newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#define LATINIME_DIC_NODE_UTILS_H
|
#define LATINIME_DIC_NODE_UTILS_H
|
||||||
|
|
||||||
#include "defines.h"
|
#include "defines.h"
|
||||||
|
#include "utils/int_array_view.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
|
@ -30,7 +31,7 @@ class DicNodeUtils {
|
||||||
public:
|
public:
|
||||||
static void initAsRoot(
|
static void initAsRoot(
|
||||||
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
|
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
|
||||||
const int *const prevWordIds, DicNode *const newRootDicNode);
|
const WordIdArrayView prevWordIds, DicNode *const newRootDicNode);
|
||||||
static void initAsRootWithPreviousWord(
|
static void initAsRootWithPreviousWord(
|
||||||
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
|
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
|
||||||
const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode);
|
const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode);
|
||||||
|
|
|
@ -18,8 +18,10 @@
|
||||||
#define LATINIME_DIC_NODE_PROPERTIES_H
|
#define LATINIME_DIC_NODE_PROPERTIES_H
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
#include "defines.h"
|
#include "defines.h"
|
||||||
|
#include "utils/int_array_view.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
|
@ -36,23 +38,23 @@ class DicNodeProperties {
|
||||||
|
|
||||||
// Should be called only once per DicNode is initialized.
|
// Should be called only once per DicNode is initialized.
|
||||||
void init(const int childrenPos, const int nodeCodePoint, const int wordId,
|
void init(const int childrenPos, const int nodeCodePoint, const int wordId,
|
||||||
const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordIds) {
|
const uint16_t depth, const uint16_t leavingDepth, const WordIdArrayView prevWordIds) {
|
||||||
mChildrenPtNodeArrayPos = childrenPos;
|
mChildrenPtNodeArrayPos = childrenPos;
|
||||||
mDicNodeCodePoint = nodeCodePoint;
|
mDicNodeCodePoint = nodeCodePoint;
|
||||||
mWordId = wordId;
|
mWordId = wordId;
|
||||||
mDepth = depth;
|
mDepth = depth;
|
||||||
mLeavingDepth = leavingDepth;
|
mLeavingDepth = leavingDepth;
|
||||||
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
|
prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init for root with prevWordsPtNodePos which is used for n-gram
|
// Init for root with prevWordsPtNodePos which is used for n-gram
|
||||||
void init(const int rootPtNodeArrayPos, const int *const prevWordIds) {
|
void init(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) {
|
||||||
mChildrenPtNodeArrayPos = rootPtNodeArrayPos;
|
mChildrenPtNodeArrayPos = rootPtNodeArrayPos;
|
||||||
mDicNodeCodePoint = NOT_A_CODE_POINT;
|
mDicNodeCodePoint = NOT_A_CODE_POINT;
|
||||||
mWordId = NOT_A_WORD_ID;
|
mWordId = NOT_A_WORD_ID;
|
||||||
mDepth = 0;
|
mDepth = 0;
|
||||||
mLeavingDepth = 0;
|
mLeavingDepth = 0;
|
||||||
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
|
prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */);
|
||||||
}
|
}
|
||||||
|
|
||||||
void initByCopy(const DicNodeProperties *const dicNodeProp) {
|
void initByCopy(const DicNodeProperties *const dicNodeProp) {
|
||||||
|
@ -61,7 +63,8 @@ class DicNodeProperties {
|
||||||
mWordId = dicNodeProp->mWordId;
|
mWordId = dicNodeProp->mWordId;
|
||||||
mDepth = dicNodeProp->mDepth;
|
mDepth = dicNodeProp->mDepth;
|
||||||
mLeavingDepth = dicNodeProp->mLeavingDepth;
|
mLeavingDepth = dicNodeProp->mLeavingDepth;
|
||||||
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
|
WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds)
|
||||||
|
.copyToArray(&mPrevWordIds, 0 /* offset */);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init as passing child
|
// Init as passing child
|
||||||
|
@ -71,7 +74,8 @@ class DicNodeProperties {
|
||||||
mWordId = dicNodeProp->mWordId;
|
mWordId = dicNodeProp->mWordId;
|
||||||
mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child
|
mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child
|
||||||
mLeavingDepth = dicNodeProp->mLeavingDepth;
|
mLeavingDepth = dicNodeProp->mLeavingDepth;
|
||||||
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
|
WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds)
|
||||||
|
.copyToArray(&mPrevWordIds, 0 /* offset */);
|
||||||
}
|
}
|
||||||
|
|
||||||
int getChildrenPtNodeArrayPos() const {
|
int getChildrenPtNodeArrayPos() const {
|
||||||
|
@ -99,8 +103,8 @@ class DicNodeProperties {
|
||||||
return (mChildrenPtNodeArrayPos != NOT_A_DICT_POS) || mDepth != mLeavingDepth;
|
return (mChildrenPtNodeArrayPos != NOT_A_DICT_POS) || mDepth != mLeavingDepth;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int *getPrevWordIds() const {
|
const WordIdArrayView getPrevWordIds() const {
|
||||||
return mPrevWordIds;
|
return WordIdArrayView::fromArray(mPrevWordIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
int getWordId() const {
|
int getWordId() const {
|
||||||
|
@ -116,7 +120,7 @@ class DicNodeProperties {
|
||||||
int mWordId;
|
int mWordId;
|
||||||
uint16_t mDepth;
|
uint16_t mDepth;
|
||||||
uint16_t mLeavingDepth;
|
uint16_t mLeavingDepth;
|
||||||
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIds;
|
||||||
};
|
};
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
#endif // LATINIME_DIC_NODE_PROPERTIES_H
|
#endif // LATINIME_DIC_NODE_PROPERTIES_H
|
||||||
|
|
|
@ -85,7 +85,7 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
|
const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
|
||||||
mPrevWordIds.data(), targetWordId, nullptr /* multiBigramMap */);
|
mPrevWordIds, targetWordId, nullptr /* multiBigramMap */);
|
||||||
mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
|
mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
|
||||||
wordAttributes.getProbability());
|
wordAttributes.getProbability());
|
||||||
}
|
}
|
||||||
|
@ -93,13 +93,13 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
|
||||||
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
|
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
|
||||||
SuggestionResults *const outSuggestionResults) const {
|
SuggestionResults *const outSuggestionResults) const {
|
||||||
TimeKeeper::setCurrentTime();
|
TimeKeeper::setCurrentTime();
|
||||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||||
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
|
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(),
|
||||||
true /* tryLowerCaseSearch */);
|
true /* tryLowerCaseSearch */);
|
||||||
NgramListenerForPrediction listener(prevWordsInfo,
|
const WordIdArrayView prevWordIdArrayView = WordIdArrayView::fromArray(prevWordIds);
|
||||||
WordIdArrayView::fromFixedSizeArray(prevWordIds), outSuggestionResults,
|
NgramListenerForPrediction listener(prevWordsInfo, prevWordIdArrayView, outSuggestionResults,
|
||||||
mDictionaryStructureWithBufferPolicy.get());
|
mDictionaryStructureWithBufferPolicy.get());
|
||||||
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener);
|
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIdArrayView, &listener);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Dictionary::getProbability(const int *word, int length) const {
|
int Dictionary::getProbability(const int *word, int length) const {
|
||||||
|
@ -119,13 +119,13 @@ int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, co
|
||||||
CodePointArrayView(word, length), false /* forceLowerCaseSearch */);
|
CodePointArrayView(word, length), false /* forceLowerCaseSearch */);
|
||||||
if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY;
|
if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY;
|
||||||
if (!prevWordsInfo) {
|
if (!prevWordsInfo) {
|
||||||
return getDictionaryStructurePolicy()->getProbabilityOfWord(
|
return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId);
|
||||||
nullptr /* prevWordsPtNodePos */, wordId);
|
|
||||||
}
|
}
|
||||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||||
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
|
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(),
|
||||||
true /* tryLowerCaseSearch */);
|
true /* tryLowerCaseSearch */);
|
||||||
return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId);
|
return getDictionaryStructurePolicy()->getProbabilityOfWord(
|
||||||
|
IntArrayView::fromArray(prevWordIds), wordId);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Dictionary::addUnigramEntry(const int *const word, const int length,
|
bool Dictionary::addUnigramEntry(const int *const word, const int length,
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "suggest/core/dictionary/digraph_utils.h"
|
#include "suggest/core/dictionary/digraph_utils.h"
|
||||||
#include "suggest/core/session/prev_words_info.h"
|
#include "suggest/core/session/prev_words_info.h"
|
||||||
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
|
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
|
||||||
|
#include "utils/int_array_view.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
|
@ -34,11 +35,12 @@ namespace latinime {
|
||||||
|
|
||||||
// No prev words information.
|
// No prev words information.
|
||||||
PrevWordsInfo emptyPrevWordsInfo;
|
PrevWordsInfo emptyPrevWordsInfo;
|
||||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||||
emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds,
|
emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds.data(),
|
||||||
false /* tryLowerCaseSearch */);
|
false /* tryLowerCaseSearch */);
|
||||||
current.emplace_back();
|
current.emplace_back();
|
||||||
DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, ¤t.front());
|
DicNodeUtils::initAsRoot(dictionaryStructurePolicy,
|
||||||
|
IntArrayView::fromArray(prevWordIds), ¤t.front());
|
||||||
for (int i = 0; i < codePointCount; ++i) {
|
for (int i = 0; i < codePointCount; ++i) {
|
||||||
// The base-lower input is used to ignore case errors and accent errors.
|
// The base-lower input is used to ignore case errors and accent errors.
|
||||||
const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]);
|
const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]);
|
||||||
|
|
|
@ -35,9 +35,9 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP =
|
||||||
// Also caches the bigrams if there is space remaining and they have not been cached already.
|
// Also caches the bigrams if there is space remaining and they have not been cached already.
|
||||||
int MultiBigramMap::getBigramProbability(
|
int MultiBigramMap::getBigramProbability(
|
||||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||||
const int *const prevWordIds, const int nextWordId,
|
const WordIdArrayView prevWordIds, const int nextWordId,
|
||||||
const int unigramProbability) {
|
const int unigramProbability) {
|
||||||
if (!prevWordIds || prevWordIds[0] == NOT_A_WORD_ID) {
|
if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) {
|
||||||
return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY);
|
return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY);
|
||||||
}
|
}
|
||||||
const auto mapPosition = mBigramMaps.find(prevWordIds[0]);
|
const auto mapPosition = mBigramMaps.find(prevWordIds[0]);
|
||||||
|
@ -56,7 +56,7 @@ int MultiBigramMap::getBigramProbability(
|
||||||
|
|
||||||
void MultiBigramMap::BigramMap::init(
|
void MultiBigramMap::BigramMap::init(
|
||||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||||
const int *const prevWordIds) {
|
const WordIdArrayView prevWordIds) {
|
||||||
structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */);
|
structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,16 +83,13 @@ void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, const i
|
||||||
|
|
||||||
void MultiBigramMap::addBigramsForWord(
|
void MultiBigramMap::addBigramsForWord(
|
||||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||||
const int *const prevWordIds) {
|
const WordIdArrayView prevWordIds) {
|
||||||
if (prevWordIds) {
|
|
||||||
mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds);
|
mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int MultiBigramMap::readBigramProbabilityFromBinaryDictionary(
|
int MultiBigramMap::readBigramProbabilityFromBinaryDictionary(
|
||||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||||
const int *const prevWordIds, const int nextWordId,
|
const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability) {
|
||||||
const int unigramProbability) {
|
|
||||||
const int bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId);
|
const int bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId);
|
||||||
if (bigramProbability != NOT_A_PROBABILITY) {
|
if (bigramProbability != NOT_A_PROBABILITY) {
|
||||||
return bigramProbability;
|
return bigramProbability;
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include "suggest/core/dictionary/bloom_filter.h"
|
#include "suggest/core/dictionary/bloom_filter.h"
|
||||||
#include "suggest/core/dictionary/ngram_listener.h"
|
#include "suggest/core/dictionary/ngram_listener.h"
|
||||||
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
|
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
|
||||||
|
#include "utils/int_array_view.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
|
@ -39,7 +40,7 @@ class MultiBigramMap {
|
||||||
// Look up the bigram probability for the given word pair from the cached bigram maps.
|
// Look up the bigram probability for the given word pair from the cached bigram maps.
|
||||||
// Also caches the bigrams if there is space remaining and they have not been cached already.
|
// Also caches the bigrams if there is space remaining and they have not been cached already.
|
||||||
int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||||
const int *const prevWordIds, const int nextWordId, const int unigramProbability);
|
const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability);
|
||||||
|
|
||||||
void clear() {
|
void clear() {
|
||||||
mBigramMaps.clear();
|
mBigramMaps.clear();
|
||||||
|
@ -57,7 +58,7 @@ class MultiBigramMap {
|
||||||
virtual ~BigramMap() {}
|
virtual ~BigramMap() {}
|
||||||
|
|
||||||
void init(const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
void init(const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||||
const int *const prevWordIds);
|
const WordIdArrayView prevWordIds);
|
||||||
int getBigramProbability(
|
int getBigramProbability(
|
||||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||||
const int nextWordId, const int unigramProbability) const;
|
const int nextWordId, const int unigramProbability) const;
|
||||||
|
@ -70,11 +71,11 @@ class MultiBigramMap {
|
||||||
};
|
};
|
||||||
|
|
||||||
void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||||
const int *const prevWordIds);
|
const WordIdArrayView prevWordIds);
|
||||||
|
|
||||||
int readBigramProbabilityFromBinaryDictionary(
|
int readBigramProbabilityFromBinaryDictionary(
|
||||||
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
const DictionaryStructureWithBufferPolicy *const structurePolicy,
|
||||||
const int *const prevWordIds, const int nextWordId, const int unigramProbability);
|
const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability);
|
||||||
|
|
||||||
static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP;
|
static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP;
|
||||||
std::unordered_map<int, BigramMap> mBigramMaps;
|
std::unordered_map<int, BigramMap> mBigramMaps;
|
||||||
|
|
|
@ -58,15 +58,15 @@ class DictionaryStructureWithBufferPolicy {
|
||||||
virtual int getWordId(const CodePointArrayView wordCodePoints,
|
virtual int getWordId(const CodePointArrayView wordCodePoints,
|
||||||
const bool forceLowerCaseSearch) const = 0;
|
const bool forceLowerCaseSearch) const = 0;
|
||||||
|
|
||||||
virtual const WordAttributes getWordAttributesInContext(const int *const prevWordIds,
|
virtual const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
|
||||||
const int wordId, MultiBigramMap *const multiBigramMap) const = 0;
|
const int wordId, MultiBigramMap *const multiBigramMap) const = 0;
|
||||||
|
|
||||||
// TODO: Remove
|
// TODO: Remove
|
||||||
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
|
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
|
||||||
|
|
||||||
virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0;
|
virtual int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const = 0;
|
||||||
|
|
||||||
virtual void iterateNgramEntries(const int *const prevWordIds,
|
virtual void iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||||
NgramListener *const listener) const = 0;
|
NgramListener *const listener) const = 0;
|
||||||
|
|
||||||
virtual BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const = 0;
|
virtual BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const = 0;
|
||||||
|
|
|
@ -35,7 +35,7 @@ void DicTraverseSession::init(const Dictionary *const dictionary,
|
||||||
mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy()
|
mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy()
|
||||||
->getMultiWordCostMultiplier();
|
->getMultiWordCostMultiplier();
|
||||||
mSuggestOptions = suggestOptions;
|
mSuggestOptions = suggestOptions;
|
||||||
prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds,
|
prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds.data(),
|
||||||
true /* tryLowerCaseSearch */);
|
true /* tryLowerCaseSearch */);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include "suggest/core/dicnode/dic_nodes_cache.h"
|
#include "suggest/core/dicnode/dic_nodes_cache.h"
|
||||||
#include "suggest/core/dictionary/multi_bigram_map.h"
|
#include "suggest/core/dictionary/multi_bigram_map.h"
|
||||||
#include "suggest/core/layout/proximity_info_state.h"
|
#include "suggest/core/layout/proximity_info_state.h"
|
||||||
|
#include "utils/int_array_view.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
|
@ -55,9 +56,7 @@ class DicTraverseSession {
|
||||||
mMultiWordCostMultiplier(1.0f) {
|
mMultiWordCostMultiplier(1.0f) {
|
||||||
// NOTE: mProximityInfoStates is an array of instances.
|
// NOTE: mProximityInfoStates is an array of instances.
|
||||||
// No need to initialize it explicitly here.
|
// No need to initialize it explicitly here.
|
||||||
for (size_t i = 0; i < NELEMS(mPrevWordsIds); ++i) {
|
mPrevWordsIds.fill(NOT_A_DICT_POS);
|
||||||
mPrevWordsIds[i] = NOT_A_DICT_POS;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Non virtual inline destructor -- never inherit this class
|
// Non virtual inline destructor -- never inherit this class
|
||||||
|
@ -79,7 +78,7 @@ class DicTraverseSession {
|
||||||
//--------------------
|
//--------------------
|
||||||
const ProximityInfo *getProximityInfo() const { return mProximityInfo; }
|
const ProximityInfo *getProximityInfo() const { return mProximityInfo; }
|
||||||
const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; }
|
const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; }
|
||||||
const int *getPrevWordIds() const { return mPrevWordsIds; }
|
const WordIdArrayView getPrevWordIds() const { return IntArrayView::fromArray(mPrevWordsIds); }
|
||||||
DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
|
DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
|
||||||
MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; }
|
MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; }
|
||||||
const ProximityInfoState *getProximityInfoState(int id) const {
|
const ProximityInfoState *getProximityInfoState(int id) const {
|
||||||
|
@ -166,7 +165,7 @@ class DicTraverseSession {
|
||||||
const int *const inputYs, const int *const times, const int *const pointerIds,
|
const int *const inputYs, const int *const times, const int *const pointerIds,
|
||||||
const int inputSize, const float maxSpatialDistance, const int maxPointerCount);
|
const int inputSize, const float maxSpatialDistance, const int maxPointerCount);
|
||||||
|
|
||||||
int mPrevWordsIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordsIds;
|
||||||
const ProximityInfo *mProximityInfo;
|
const ProximityInfo *mProximityInfo;
|
||||||
const Dictionary *mDictionary;
|
const Dictionary *mDictionary;
|
||||||
const SuggestOptions *mSuggestOptions;
|
const SuggestOptions *mSuggestOptions;
|
||||||
|
|
|
@ -116,7 +116,7 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
||||||
}
|
}
|
||||||
|
|
||||||
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
||||||
const int *const prevWordIds, const int wordId,
|
const WordIdArrayView prevWordIds, const int wordId,
|
||||||
MultiBigramMap *const multiBigramMap) const {
|
MultiBigramMap *const multiBigramMap) const {
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return WordAttributes();
|
return WordAttributes();
|
||||||
|
@ -128,7 +128,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
||||||
prevWordIds, wordId, ptNodeParams.getProbability());
|
prevWordIds, wordId, ptNodeParams.getProbability());
|
||||||
return getWordAttributes(probability, ptNodeParams);
|
return getWordAttributes(probability, ptNodeParams);
|
||||||
}
|
}
|
||||||
if (prevWordIds) {
|
if (!prevWordIds.empty()) {
|
||||||
const int probability = getProbabilityOfWord(prevWordIds, wordId);
|
const int probability = getProbabilityOfWord(prevWordIds, wordId);
|
||||||
if (probability != NOT_A_PROBABILITY) {
|
if (probability != NOT_A_PROBABILITY) {
|
||||||
return getWordAttributes(probability, ptNodeParams);
|
return getWordAttributes(probability, ptNodeParams);
|
||||||
|
@ -160,7 +160,7 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
|
||||||
const int wordId) const {
|
const int wordId) const {
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return NOT_A_PROBABILITY;
|
return NOT_A_PROBABILITY;
|
||||||
|
@ -170,7 +170,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
||||||
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
|
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
|
||||||
return NOT_A_PROBABILITY;
|
return NOT_A_PROBABILITY;
|
||||||
}
|
}
|
||||||
if (prevWordIds) {
|
if (!prevWordIds.empty()) {
|
||||||
const int bigramsPosition = getBigramsPositionOfPtNode(
|
const int bigramsPosition = getBigramsPositionOfPtNode(
|
||||||
getTerminalPtNodePosFromWordId(prevWordIds[0]));
|
getTerminalPtNodePosFromWordId(prevWordIds[0]));
|
||||||
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
|
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
|
||||||
|
@ -186,9 +186,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
||||||
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
|
void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||||
NgramListener *const listener) const {
|
NgramListener *const listener) const {
|
||||||
if (!prevWordIds) {
|
if (prevWordIds.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const int bigramsPosition = getBigramsPositionOfPtNode(
|
const int bigramsPosition = getBigramsPositionOfPtNode(
|
||||||
|
|
|
@ -91,14 +91,15 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
|
|
||||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||||
|
|
||||||
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
|
const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
|
||||||
MultiBigramMap *const multiBigramMap) const;
|
const int wordId, MultiBigramMap *const multiBigramMap) const;
|
||||||
|
|
||||||
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
||||||
|
|
||||||
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
|
int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const;
|
||||||
|
|
||||||
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
|
void iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||||
|
NgramListener *const listener) const;
|
||||||
|
|
||||||
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;
|
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;
|
||||||
|
|
||||||
|
|
|
@ -282,8 +282,9 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
||||||
return getWordIdFromTerminalPtNodePos(ptNodePos);
|
return getWordIdFromTerminalPtNodePos(ptNodePos);
|
||||||
}
|
}
|
||||||
|
|
||||||
const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *const prevWordIds,
|
const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(
|
||||||
const int wordId, MultiBigramMap *const multiBigramMap) const {
|
const WordIdArrayView prevWordIds, const int wordId,
|
||||||
|
MultiBigramMap *const multiBigramMap) const {
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return WordAttributes();
|
return WordAttributes();
|
||||||
}
|
}
|
||||||
|
@ -295,7 +296,7 @@ const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *c
|
||||||
prevWordIds, wordId, ptNodeParams.getProbability());
|
prevWordIds, wordId, ptNodeParams.getProbability());
|
||||||
return getWordAttributes(probability, ptNodeParams);
|
return getWordAttributes(probability, ptNodeParams);
|
||||||
}
|
}
|
||||||
if (prevWordIds) {
|
if (!prevWordIds.empty()) {
|
||||||
const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId);
|
const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId);
|
||||||
if (bigramProbability != NOT_A_PROBABILITY) {
|
if (bigramProbability != NOT_A_PROBABILITY) {
|
||||||
return getWordAttributes(bigramProbability, ptNodeParams);
|
return getWordAttributes(bigramProbability, ptNodeParams);
|
||||||
|
@ -327,7 +328,8 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const int wordId) const {
|
int PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
|
||||||
|
const int wordId) const {
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return NOT_A_PROBABILITY;
|
return NOT_A_PROBABILITY;
|
||||||
}
|
}
|
||||||
|
@ -340,7 +342,7 @@ int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const
|
||||||
// for shortcuts).
|
// for shortcuts).
|
||||||
return NOT_A_PROBABILITY;
|
return NOT_A_PROBABILITY;
|
||||||
}
|
}
|
||||||
if (prevWordIds) {
|
if (!prevWordIds.empty()) {
|
||||||
const int bigramsPosition = getBigramsPositionOfPtNode(
|
const int bigramsPosition = getBigramsPositionOfPtNode(
|
||||||
getTerminalPtNodePosFromWordId(prevWordIds[0]));
|
getTerminalPtNodePosFromWordId(prevWordIds[0]));
|
||||||
BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition);
|
BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition);
|
||||||
|
@ -356,9 +358,9 @@ int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const
|
||||||
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
|
void PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||||
NgramListener *const listener) const {
|
NgramListener *const listener) const {
|
||||||
if (!prevWordIds) {
|
if (prevWordIds.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const int bigramsPosition = getBigramsPositionOfPtNode(
|
const int bigramsPosition = getBigramsPositionOfPtNode(
|
||||||
|
@ -371,8 +373,7 @@ void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator(
|
BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator(const int wordId) const {
|
||||||
const int wordId) const {
|
|
||||||
const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId));
|
const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId));
|
||||||
return BinaryDictionaryShortcutIterator(&mShortcutListPolicy, shortcutPos);
|
return BinaryDictionaryShortcutIterator(&mShortcutListPolicy, shortcutPos);
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,14 +66,15 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
|
|
||||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||||
|
|
||||||
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
|
const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
|
||||||
MultiBigramMap *const multiBigramMap) const;
|
const int wordId, MultiBigramMap *const multiBigramMap) const;
|
||||||
|
|
||||||
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
||||||
|
|
||||||
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
|
int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const;
|
||||||
|
|
||||||
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
|
void iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||||
|
NgramListener *const listener) const;
|
||||||
|
|
||||||
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;
|
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h"
|
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "suggest/core/dicnode/dic_node.h"
|
#include "suggest/core/dicnode/dic_node.h"
|
||||||
|
@ -111,7 +112,7 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
||||||
}
|
}
|
||||||
|
|
||||||
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
||||||
const int *const prevWordIds, const int wordId,
|
const WordIdArrayView prevWordIds, const int wordId,
|
||||||
MultiBigramMap *const multiBigramMap) const {
|
MultiBigramMap *const multiBigramMap) const {
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return WordAttributes();
|
return WordAttributes();
|
||||||
|
@ -121,27 +122,11 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
||||||
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
||||||
// TODO: Support n-gram.
|
// TODO: Support n-gram.
|
||||||
return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
|
return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
|
||||||
WordIdArrayView::singleElementView(prevWordIds), wordId), ptNodeParams.isBlacklisted(),
|
prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(),
|
||||||
ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
|
ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
|
||||||
const int bigramProbability) const {
|
|
||||||
if (mHeaderPolicy->isDecayingDict()) {
|
|
||||||
// Both probabilities are encoded. Decode them and get probability.
|
|
||||||
return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability);
|
|
||||||
} else {
|
|
||||||
if (unigramProbability == NOT_A_PROBABILITY) {
|
|
||||||
return NOT_A_PROBABILITY;
|
|
||||||
} else if (bigramProbability == NOT_A_PROBABILITY) {
|
|
||||||
return ProbabilityUtils::backoff(unigramProbability);
|
|
||||||
} else {
|
|
||||||
return bigramProbability;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
|
||||||
const int wordId) const {
|
const int wordId) const {
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return NOT_A_PROBABILITY;
|
return NOT_A_PROBABILITY;
|
||||||
|
@ -152,11 +137,10 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
||||||
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
|
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
|
||||||
return NOT_A_PROBABILITY;
|
return NOT_A_PROBABILITY;
|
||||||
}
|
}
|
||||||
if (prevWordIds) {
|
|
||||||
// TODO: Support n-gram.
|
// TODO: Support n-gram.
|
||||||
const ProbabilityEntry probabilityEntry =
|
const ProbabilityEntry probabilityEntry =
|
||||||
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
|
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
|
||||||
IntArrayView::singleElementView(prevWordIds), wordId);
|
prevWordIds.limit(1 /* maxSize */), wordId);
|
||||||
if (!probabilityEntry.isValid()) {
|
if (!probabilityEntry.isValid()) {
|
||||||
return NOT_A_PROBABILITY;
|
return NOT_A_PROBABILITY;
|
||||||
}
|
}
|
||||||
|
@ -166,8 +150,6 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
||||||
} else {
|
} else {
|
||||||
return probabilityEntry.getProbability();
|
return probabilityEntry.getProbability();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
|
BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
|
||||||
|
@ -176,15 +158,15 @@ BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
|
||||||
return BinaryDictionaryShortcutIterator(&mShortcutPolicy, shortcutPos);
|
return BinaryDictionaryShortcutIterator(&mShortcutPolicy, shortcutPos);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
|
void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||||
NgramListener *const listener) const {
|
NgramListener *const listener) const {
|
||||||
if (!prevWordIds) {
|
if (prevWordIds.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// TODO: Support n-gram.
|
// TODO: Support n-gram.
|
||||||
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
|
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
|
||||||
for (const auto entry : languageModelDictContent->getProbabilityEntries(
|
for (const auto entry : languageModelDictContent->getProbabilityEntries(
|
||||||
WordIdArrayView::singleElementView(prevWordIds))) {
|
prevWordIds.limit(1 /* maxSize */))) {
|
||||||
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
|
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
|
||||||
const int probability = probabilityEntry.hasHistoricalInfo() ?
|
const int probability = probabilityEntry.hasHistoricalInfo() ?
|
||||||
ForgettingCurveUtils::decodeProbability(
|
ForgettingCurveUtils::decodeProbability(
|
||||||
|
@ -321,8 +303,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
|
||||||
"length: %zd", bigramProperty->getTargetCodePoints()->size());
|
"length: %zd", bigramProperty->getTargetCodePoints()->size());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||||
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
|
prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */);
|
||||||
// TODO: Support N-gram.
|
// TODO: Support N-gram.
|
||||||
if (prevWordIds[0] == NOT_A_WORD_ID) {
|
if (prevWordIds[0] == NOT_A_WORD_ID) {
|
||||||
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
|
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
|
||||||
|
@ -337,7 +319,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Refresh word ids.
|
// Refresh word ids.
|
||||||
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
|
prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */);
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -348,14 +330,14 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
bool addedNewEntry = false;
|
bool addedNewEntry = false;
|
||||||
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos;
|
||||||
for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
|
for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) {
|
||||||
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
|
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
|
||||||
->getTerminalPtNodePosition(prevWordIds[i]);
|
->getTerminalPtNodePosition(prevWordIds[i]);
|
||||||
}
|
}
|
||||||
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
|
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
|
||||||
->getTerminalPtNodePosition(wordId);
|
->getTerminalPtNodePosition(wordId);
|
||||||
if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
|
if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos),
|
||||||
wordPtNodePos, bigramProperty, &addedNewEntry)) {
|
wordPtNodePos, bigramProperty, &addedNewEntry)) {
|
||||||
if (addedNewEntry) {
|
if (addedNewEntry) {
|
||||||
mBigramCount++;
|
mBigramCount++;
|
||||||
|
@ -385,8 +367,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
|
||||||
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
|
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
|
||||||
wordCodePoints.size());
|
wordCodePoints.size());
|
||||||
}
|
}
|
||||||
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
|
||||||
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */);
|
prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSerch */);
|
||||||
// TODO: Support N-gram.
|
// TODO: Support N-gram.
|
||||||
if (prevWordIds[0] == NOT_A_WORD_ID) {
|
if (prevWordIds[0] == NOT_A_WORD_ID) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -395,14 +377,14 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos;
|
||||||
for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
|
for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) {
|
||||||
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
|
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
|
||||||
->getTerminalPtNodePosition(prevWordIds[i]);
|
->getTerminalPtNodePosition(prevWordIds[i]);
|
||||||
}
|
}
|
||||||
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
|
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
|
||||||
->getTerminalPtNodePosition(wordId);
|
->getTerminalPtNodePosition(wordId);
|
||||||
if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
|
if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos),
|
||||||
wordPtNodePos)) {
|
wordPtNodePos)) {
|
||||||
mBigramCount--;
|
mBigramCount--;
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -68,14 +68,19 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
|
|
||||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||||
|
|
||||||
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
|
const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
|
||||||
MultiBigramMap *const multiBigramMap) const;
|
const int wordId, MultiBigramMap *const multiBigramMap) const;
|
||||||
|
|
||||||
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
// TODO: Remove
|
||||||
|
int getProbability(const int unigramProbability, const int bigramProbability) const {
|
||||||
|
// Not used.
|
||||||
|
return NOT_A_PROBABILITY;
|
||||||
|
}
|
||||||
|
|
||||||
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
|
int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const;
|
||||||
|
|
||||||
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
|
void iterateNgramEntries(const WordIdArrayView prevWordIds,
|
||||||
|
NgramListener *const listener) const;
|
||||||
|
|
||||||
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;
|
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;
|
||||||
|
|
||||||
|
|
|
@ -57,9 +57,9 @@ class IntArrayView {
|
||||||
explicit IntArrayView(const std::vector<int> &vector)
|
explicit IntArrayView(const std::vector<int> &vector)
|
||||||
: mPtr(vector.data()), mSize(vector.size()) {}
|
: mPtr(vector.data()), mSize(vector.size()) {}
|
||||||
|
|
||||||
template <int N>
|
template <size_t N>
|
||||||
AK_FORCE_INLINE static IntArrayView fromFixedSizeArray(const int (&array)[N]) {
|
AK_FORCE_INLINE static IntArrayView fromArray(const std::array<int, N> &array) {
|
||||||
return IntArrayView(array, N);
|
return IntArrayView(array.data(), array.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a view that points one int object.
|
// Returns a view that points one int object.
|
||||||
|
@ -120,6 +120,8 @@ class IntArrayView {
|
||||||
using WordIdArrayView = IntArrayView;
|
using WordIdArrayView = IntArrayView;
|
||||||
using PtNodePosArrayView = IntArrayView;
|
using PtNodePosArrayView = IntArrayView;
|
||||||
using CodePointArrayView = IntArrayView;
|
using CodePointArrayView = IntArrayView;
|
||||||
|
template <size_t size>
|
||||||
|
using WordIdArray = std::array<int, size>;
|
||||||
|
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
#endif // LATINIME_MEMORY_VIEW_H
|
#endif // LATINIME_MEMORY_VIEW_H
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "utils/int_array_view.h"
|
#include "utils/int_array_view.h"
|
||||||
|
@ -97,8 +98,8 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
|
||||||
const int bigramProbability = 20;
|
const int bigramProbability = 20;
|
||||||
const int trigramProbability = 30;
|
const int trigramProbability = 30;
|
||||||
const int wordId = 100;
|
const int wordId = 100;
|
||||||
const int prevWordIdArray[] = { 1, 2 };
|
const std::array<int, 2> prevWordIdArray = {{ 1, 2 }};
|
||||||
const WordIdArrayView prevWordIds = WordIdArrayView::fromFixedSizeArray(prevWordIdArray);
|
const WordIdArrayView prevWordIds = WordIdArrayView::fromArray(prevWordIdArray);
|
||||||
|
|
||||||
const ProbabilityEntry probabilityEntry(flag, probability);
|
const ProbabilityEntry probabilityEntry(flag, probability);
|
||||||
languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
||||||
|
|
|
@ -46,8 +46,8 @@ TEST(IntArrayViewTest, TestIteration) {
|
||||||
|
|
||||||
TEST(IntArrayViewTest, TestConstructFromArray) {
|
TEST(IntArrayViewTest, TestConstructFromArray) {
|
||||||
const size_t ARRAY_SIZE = 100;
|
const size_t ARRAY_SIZE = 100;
|
||||||
int intArray[ARRAY_SIZE];
|
std::array<int, ARRAY_SIZE> intArray;
|
||||||
const auto intArrayView = IntArrayView::fromFixedSizeArray(intArray);
|
const auto intArrayView = IntArrayView::fromArray(intArray);
|
||||||
EXPECT_EQ(ARRAY_SIZE, intArrayView.size());
|
EXPECT_EQ(ARRAY_SIZE, intArrayView.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue