Merge "Use word id for methods related to n-grams."

This commit is contained in:
Keisuke Kuroyanagi 2014-09-03 07:42:02 +00:00 committed by Android (Google) Code Review
commit cc6081c51b
20 changed files with 252 additions and 239 deletions

View file

@ -103,10 +103,10 @@ class DicNode {
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
} }
// Init for root with prevWordsPtNodePos which is used for n-gram // Init for root with prevWordIds which is used for n-gram
void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordsPtNodePos) { void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordIds) {
mIsCachedForNextSuggestion = false; mIsCachedForNextSuggestion = false;
mDicNodeProperties.init(rootPtNodeArrayPos, prevWordsPtNodePos); mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds);
mDicNodeState.init(); mDicNodeState.init();
PROF_NODE_RESET(mProfiler); PROF_NODE_RESET(mProfiler);
} }
@ -114,12 +114,12 @@ class DicNode {
// Init for root with previous word // Init for root with previous word
void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) { void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
int newPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int newPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
newPrevWordsPtNodePos[0] = dicNode->mDicNodeProperties.getPtNodePos(); newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId();
for (size_t i = 1; i < NELEMS(newPrevWordsPtNodePos); ++i) { for (size_t i = 1; i < NELEMS(newPrevWordIds); ++i) {
newPrevWordsPtNodePos[i] = dicNode->getPrevWordsTerminalPtNodePos()[i - 1]; newPrevWordIds[i] = dicNode->getPrevWordIds()[i - 1];
} }
mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordsPtNodePos); mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordIds);
mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
dicNode->mDicNodeProperties.getDepth()); dicNode->mDicNodeProperties.getDepth());
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
@ -145,7 +145,7 @@ class DicNode {
dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount);
mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0], mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0],
probability, wordId, hasChildren, isBlacklistedOrNotAWord, newDepth, probability, wordId, hasChildren, isBlacklistedOrNotAWord, newDepth,
newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordsTerminalPtNodePos()); newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds());
mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount,
mergedNodeCodePoints); mergedNodeCodePoints);
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
@ -204,13 +204,18 @@ class DicNode {
} }
// Used to get n-gram probability in DicNodeUtils. // Used to get n-gram probability in DicNodeUtils.
int getWordId() const {
return mDicNodeProperties.getWordId();
}
// TODO: Remove
int getPtNodePos() const { int getPtNodePos() const {
return mDicNodeProperties.getPtNodePos(); return mDicNodeProperties.getPtNodePos();
} }
// TODO: Use view class to return PtNodePos array. // TODO: Use view class to return word id array.
const int *getPrevWordsTerminalPtNodePos() const { const int *getPrevWordIds() const {
return mDicNodeProperties.getPrevWordsTerminalPtNodePos(); return mDicNodeProperties.getPrevWordIds();
} }
// Used in DicNodeUtils // Used in DicNodeUtils

View file

@ -29,8 +29,8 @@ namespace latinime {
/* static */ void DicNodeUtils::initAsRoot( /* static */ void DicNodeUtils::initAsRoot(
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
const int *const prevWordsPtNodePos, DicNode *const newRootDicNode) { const int *const prevWordIds, DicNode *const newRootDicNode) {
newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordsPtNodePos); newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds);
} }
/*static */ void DicNodeUtils::initAsRootWithPreviousWord( /*static */ void DicNodeUtils::initAsRootWithPreviousWord(
@ -86,9 +86,9 @@ namespace latinime {
const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) { const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) {
const int unigramProbability = dicNode->getProbability(); const int unigramProbability = dicNode->getProbability();
if (multiBigramMap) { if (multiBigramMap) {
const int *const prevWordsPtNodePos = dicNode->getPrevWordsTerminalPtNodePos(); const int *const prevWordIds = dicNode->getPrevWordIds();
return multiBigramMap->getBigramProbability(dictionaryStructurePolicy, return multiBigramMap->getBigramProbability(dictionaryStructurePolicy,
prevWordsPtNodePos, dicNode->getPtNodePos(), unigramProbability); prevWordIds, dicNode->getWordId(), unigramProbability);
} }
return dictionaryStructurePolicy->getProbability(unigramProbability, return dictionaryStructurePolicy->getProbability(unigramProbability,
NOT_A_PROBABILITY); NOT_A_PROBABILITY);

View file

@ -30,7 +30,7 @@ class DicNodeUtils {
public: public:
static void initAsRoot( static void initAsRoot(
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
const int *const prevWordPtNodePos, DicNode *const newRootDicNode); const int *const prevWordIds, DicNode *const newRootDicNode);
static void initAsRootWithPreviousWord( static void initAsRootWithPreviousWord(
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode); const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode);

View file

@ -39,7 +39,7 @@ class DicNodeProperties {
// Should be called only once per DicNode is initialized. // Should be called only once per DicNode is initialized.
void init(const int pos, const int childrenPos, const int nodeCodePoint, const int probability, void init(const int pos, const int childrenPos, const int nodeCodePoint, const int probability,
const int wordId, const bool hasChildren, const bool isBlacklistedOrNotAWord, const int wordId, const bool hasChildren, const bool isBlacklistedOrNotAWord,
const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordsNodePos) { const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordIds) {
mPtNodePos = pos; mPtNodePos = pos;
mChildrenPtNodeArrayPos = childrenPos; mChildrenPtNodeArrayPos = childrenPos;
mDicNodeCodePoint = nodeCodePoint; mDicNodeCodePoint = nodeCodePoint;
@ -49,11 +49,11 @@ class DicNodeProperties {
mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord; mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord;
mDepth = depth; mDepth = depth;
mLeavingDepth = leavingDepth; mLeavingDepth = leavingDepth;
memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos)); memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
} }
// Init for root with prevWordsPtNodePos which is used for n-gram // Init for root with prevWordsPtNodePos which is used for n-gram
void init(const int rootPtNodeArrayPos, const int *const prevWordsNodePos) { void init(const int rootPtNodeArrayPos, const int *const prevWordIds) {
mPtNodePos = NOT_A_DICT_POS; mPtNodePos = NOT_A_DICT_POS;
mChildrenPtNodeArrayPos = rootPtNodeArrayPos; mChildrenPtNodeArrayPos = rootPtNodeArrayPos;
mDicNodeCodePoint = NOT_A_CODE_POINT; mDicNodeCodePoint = NOT_A_CODE_POINT;
@ -63,7 +63,7 @@ class DicNodeProperties {
mIsBlacklistedOrNotAWord = false; mIsBlacklistedOrNotAWord = false;
mDepth = 0; mDepth = 0;
mLeavingDepth = 0; mLeavingDepth = 0;
memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos)); memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
} }
void initByCopy(const DicNodeProperties *const dicNodeProp) { void initByCopy(const DicNodeProperties *const dicNodeProp) {
@ -76,8 +76,7 @@ class DicNodeProperties {
mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord; mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord;
mDepth = dicNodeProp->mDepth; mDepth = dicNodeProp->mDepth;
mLeavingDepth = dicNodeProp->mLeavingDepth; mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos, memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
sizeof(mPrevWordsTerminalPtNodePos));
} }
// Init as passing child // Init as passing child
@ -91,8 +90,7 @@ class DicNodeProperties {
mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord; mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord;
mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child
mLeavingDepth = dicNodeProp->mLeavingDepth; mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos, memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
sizeof(mPrevWordsTerminalPtNodePos));
} }
int getPtNodePos() const { int getPtNodePos() const {
@ -132,8 +130,12 @@ class DicNodeProperties {
return mIsBlacklistedOrNotAWord; return mIsBlacklistedOrNotAWord;
} }
const int *getPrevWordsTerminalPtNodePos() const { const int *getPrevWordIds() const {
return mPrevWordsTerminalPtNodePos; return mPrevWordIds;
}
int getWordId() const {
return mWordId;
} }
private: private:
@ -149,7 +151,7 @@ class DicNodeProperties {
bool mIsBlacklistedOrNotAWord; bool mIsBlacklistedOrNotAWord;
uint16_t mDepth; uint16_t mDepth;
uint16_t mLeavingDepth; uint16_t mLeavingDepth;
int mPrevWordsTerminalPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
}; };
} // namespace latinime } // namespace latinime
#endif // LATINIME_DIC_NODE_PROPERTIES_H #endif // LATINIME_DIC_NODE_PROPERTIES_H

View file

@ -93,11 +93,10 @@ void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
TimeKeeper::setCurrentTime(); TimeKeeper::setCurrentTime();
NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults, NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults,
mDictionaryStructureWithBufferPolicy.get()); mDictionaryStructureWithBufferPolicy.get());
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos( prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos,
true /* tryLowerCaseSearch */); true /* tryLowerCaseSearch */);
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordsPtNodePos, &listener); mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener);
} }
int Dictionary::getProbability(const int *word, int length) const { int Dictionary::getProbability(const int *word, int length) const {
@ -113,18 +112,17 @@ int Dictionary::getMaxProbabilityOfExactMatches(const int *word, int length) con
int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word, int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word,
int length) const { int length) const {
TimeKeeper::setCurrentTime(); TimeKeeper::setCurrentTime();
int nextWordPos = mDictionaryStructureWithBufferPolicy->getTerminalPtNodePositionOfWord( int wordId = mDictionaryStructureWithBufferPolicy->getWordId(
CodePointArrayView(word, length), false /* forceLowerCaseSearch */); CodePointArrayView(word, length), false /* forceLowerCaseSearch */);
if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY; if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY;
if (!prevWordsInfo) { if (!prevWordsInfo) {
return getDictionaryStructurePolicy()->getProbabilityOfPtNode( return getDictionaryStructurePolicy()->getProbabilityOfWord(
nullptr /* prevWordsPtNodePos */, nextWordPos); nullptr /* prevWordsPtNodePos */, wordId);
} }
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos( prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos,
true /* tryLowerCaseSearch */); true /* tryLowerCaseSearch */);
return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsPtNodePos, nextWordPos); return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId);
} }
bool Dictionary::addUnigramEntry(const int *const word, const int length, bool Dictionary::addUnigramEntry(const int *const word, const int length,

View file

@ -34,11 +34,11 @@ namespace latinime {
// No prev words information. // No prev words information.
PrevWordsInfo emptyPrevWordsInfo; PrevWordsInfo emptyPrevWordsInfo;
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
emptyPrevWordsInfo.getPrevWordsTerminalPtNodePos(dictionaryStructurePolicy, emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds,
prevWordsPtNodePos, false /* tryLowerCaseSearch */); false /* tryLowerCaseSearch */);
current.emplace_back(); current.emplace_back();
DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordsPtNodePos, &current.front()); DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, &current.front());
for (int i = 0; i < codePointCount; ++i) { for (int i = 0; i < codePointCount; ++i) {
// The base-lower input is used to ignore case errors and accent errors. // The base-lower input is used to ignore case errors and accent errors.
const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]); const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]);

View file

@ -35,39 +35,37 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP =
// Also caches the bigrams if there is space remaining and they have not been cached already. // Also caches the bigrams if there is space remaining and they have not been cached already.
int MultiBigramMap::getBigramProbability( int MultiBigramMap::getBigramProbability(
const DictionaryStructureWithBufferPolicy *const structurePolicy, const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos, const int nextWordPosition, const int *const prevWordIds, const int nextWordId,
const int unigramProbability) { const int unigramProbability) {
if (!prevWordsPtNodePos || prevWordsPtNodePos[0] == NOT_A_DICT_POS) { if (!prevWordIds || prevWordIds[0] == NOT_A_WORD_ID) {
return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY); return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY);
} }
std::unordered_map<int, BigramMap>::const_iterator mapPosition = const auto mapPosition = mBigramMaps.find(prevWordIds[0]);
mBigramMaps.find(prevWordsPtNodePos[0]);
if (mapPosition != mBigramMaps.end()) { if (mapPosition != mBigramMaps.end()) {
return mapPosition->second.getBigramProbability(structurePolicy, nextWordPosition, return mapPosition->second.getBigramProbability(structurePolicy, nextWordId,
unigramProbability); unigramProbability);
} }
if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) {
addBigramsForWordPosition(structurePolicy, prevWordsPtNodePos); addBigramsForWord(structurePolicy, prevWordIds);
return mBigramMaps[prevWordsPtNodePos[0]].getBigramProbability(structurePolicy, return mBigramMaps[prevWordIds[0]].getBigramProbability(structurePolicy,
nextWordPosition, unigramProbability); nextWordId, unigramProbability);
} }
return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordsPtNodePos, return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordIds,
nextWordPosition, unigramProbability); nextWordId, unigramProbability);
} }
void MultiBigramMap::BigramMap::init( void MultiBigramMap::BigramMap::init(
const DictionaryStructureWithBufferPolicy *const structurePolicy, const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos) { const int *const prevWordIds) {
structurePolicy->iterateNgramEntries(prevWordsPtNodePos, this /* listener */); structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */);
} }
int MultiBigramMap::BigramMap::getBigramProbability( int MultiBigramMap::BigramMap::getBigramProbability(
const DictionaryStructureWithBufferPolicy *const structurePolicy, const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int nextWordPosition, const int unigramProbability) const { const int nextWordId, const int unigramProbability) const {
int bigramProbability = NOT_A_PROBABILITY; int bigramProbability = NOT_A_PROBABILITY;
if (mBloomFilter.isInFilter(nextWordPosition)) { if (mBloomFilter.isInFilter(nextWordId)) {
const std::unordered_map<int, int>::const_iterator bigramProbabilityIt = const auto bigramProbabilityIt = mBigramMap.find(nextWordId);
mBigramMap.find(nextWordPosition);
if (bigramProbabilityIt != mBigramMap.end()) { if (bigramProbabilityIt != mBigramMap.end()) {
bigramProbability = bigramProbabilityIt->second; bigramProbability = bigramProbabilityIt->second;
} }
@ -75,29 +73,27 @@ int MultiBigramMap::BigramMap::getBigramProbability(
return structurePolicy->getProbability(unigramProbability, bigramProbability); return structurePolicy->getProbability(unigramProbability, bigramProbability);
} }
void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, const int targetWordId) {
const int targetPtNodePos) { if (targetWordId == NOT_A_WORD_ID) {
if (targetPtNodePos == NOT_A_DICT_POS) {
return; return;
} }
mBigramMap[targetPtNodePos] = ngramProbability; mBigramMap[targetWordId] = ngramProbability;
mBloomFilter.setInFilter(targetPtNodePos); mBloomFilter.setInFilter(targetWordId);
} }
void MultiBigramMap::addBigramsForWordPosition( void MultiBigramMap::addBigramsForWord(
const DictionaryStructureWithBufferPolicy *const structurePolicy, const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos) { const int *const prevWordIds) {
if (prevWordsPtNodePos) { if (prevWordIds) {
mBigramMaps[prevWordsPtNodePos[0]].init(structurePolicy, prevWordsPtNodePos); mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds);
} }
} }
int MultiBigramMap::readBigramProbabilityFromBinaryDictionary( int MultiBigramMap::readBigramProbabilityFromBinaryDictionary(
const DictionaryStructureWithBufferPolicy *const structurePolicy, const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos, const int nextWordPosition, const int *const prevWordIds, const int nextWordId,
const int unigramProbability) { const int unigramProbability) {
const int bigramProbability = structurePolicy->getProbabilityOfPtNode(prevWordsPtNodePos, const int bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId);
nextWordPosition);
if (bigramProbability != NOT_A_PROBABILITY) { if (bigramProbability != NOT_A_PROBABILITY) {
return bigramProbability; return bigramProbability;
} }

View file

@ -39,8 +39,7 @@ class MultiBigramMap {
// Look up the bigram probability for the given word pair from the cached bigram maps. // Look up the bigram probability for the given word pair from the cached bigram maps.
// Also caches the bigrams if there is space remaining and they have not been cached already. // Also caches the bigrams if there is space remaining and they have not been cached already.
int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy, int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos, const int nextWordPosition, const int *const prevWordIds, const int nextWordId, const int unigramProbability);
const int unigramProbability);
void clear() { void clear() {
mBigramMaps.clear(); mBigramMaps.clear();
@ -58,11 +57,11 @@ class MultiBigramMap {
virtual ~BigramMap() {} virtual ~BigramMap() {}
void init(const DictionaryStructureWithBufferPolicy *const structurePolicy, void init(const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos); const int *const prevWordIds);
int getBigramProbability( int getBigramProbability(
const DictionaryStructureWithBufferPolicy *const structurePolicy, const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int nextWordPosition, const int unigramProbability) const; const int nextWordId, const int unigramProbability) const;
virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); virtual void onVisitEntry(const int ngramProbability, const int targetWordId);
private: private:
static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP; static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP;
@ -70,14 +69,12 @@ class MultiBigramMap {
BloomFilter mBloomFilter; BloomFilter mBloomFilter;
}; };
void addBigramsForWordPosition( void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy,
const DictionaryStructureWithBufferPolicy *const structurePolicy, const int *const prevWordIds);
const int *const prevWordsPtNodePos);
int readBigramProbabilityFromBinaryDictionary( int readBigramProbabilityFromBinaryDictionary(
const DictionaryStructureWithBufferPolicy *const structurePolicy, const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos, const int nextWordPosition, const int *const prevWordIds, const int nextWordId, const int unigramProbability);
const int unigramProbability);
static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP;
std::unordered_map<int, BigramMap> mBigramMaps; std::unordered_map<int, BigramMap> mBigramMaps;

View file

@ -26,7 +26,7 @@ namespace latinime {
*/ */
class NgramListener { class NgramListener {
public: public:
virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos) = 0; virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0;
virtual ~NgramListener() {}; virtual ~NgramListener() {};
protected: protected:

View file

@ -53,15 +53,14 @@ class DictionaryStructureWithBufferPolicy {
const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, const int ptNodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const = 0; int *const outUnigramProbability) const = 0;
virtual int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, virtual int getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const = 0; const bool forceLowerCaseSearch) const = 0;
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0; virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
virtual int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0;
const int ptNodePos) const = 0;
virtual void iterateNgramEntries(const int *const prevWordsPtNodePos, virtual void iterateNgramEntries(const int *const prevWordIds,
NgramListener *const listener) const = 0; NgramListener *const listener) const = 0;
virtual int getShortcutPositionOfPtNode(const int ptNodePos) const = 0; virtual int getShortcutPositionOfPtNode(const int ptNodePos) const = 0;

View file

@ -35,8 +35,8 @@ void DicTraverseSession::init(const Dictionary *const dictionary,
mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy() mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy()
->getMultiWordCostMultiplier(); ->getMultiWordCostMultiplier();
mSuggestOptions = suggestOptions; mSuggestOptions = suggestOptions;
prevWordsInfo->getPrevWordsTerminalPtNodePos( prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds,
getDictionaryStructurePolicy(), mPrevWordsPtNodePos, true /* tryLowerCaseSearch */); true /* tryLowerCaseSearch */);
} }
void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo,

View file

@ -55,8 +55,8 @@ class DicTraverseSession {
mMultiWordCostMultiplier(1.0f) { mMultiWordCostMultiplier(1.0f) {
// NOTE: mProximityInfoStates is an array of instances. // NOTE: mProximityInfoStates is an array of instances.
// No need to initialize it explicitly here. // No need to initialize it explicitly here.
for (size_t i = 0; i < NELEMS(mPrevWordsPtNodePos); ++i) { for (size_t i = 0; i < NELEMS(mPrevWordsIds); ++i) {
mPrevWordsPtNodePos[i] = NOT_A_DICT_POS; mPrevWordsIds[i] = NOT_A_DICT_POS;
} }
} }
@ -79,7 +79,7 @@ class DicTraverseSession {
//-------------------- //--------------------
const ProximityInfo *getProximityInfo() const { return mProximityInfo; } const ProximityInfo *getProximityInfo() const { return mProximityInfo; }
const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; } const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; }
const int *getPrevWordsPtNodePos() const { return mPrevWordsPtNodePos; } const int *getPrevWordIds() const { return mPrevWordsIds; }
DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; } MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; }
const ProximityInfoState *getProximityInfoState(int id) const { const ProximityInfoState *getProximityInfoState(int id) const {
@ -166,7 +166,7 @@ class DicTraverseSession {
const int *const inputYs, const int *const times, const int *const pointerIds, const int *const inputYs, const int *const times, const int *const pointerIds,
const int inputSize, const float maxSpatialDistance, const int maxPointerCount); const int inputSize, const float maxSpatialDistance, const int maxPointerCount);
int mPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int mPrevWordsIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
const ProximityInfo *mProximityInfo; const ProximityInfo *mProximityInfo;
const Dictionary *mDictionary; const Dictionary *mDictionary;
const SuggestOptions *mSuggestOptions; const SuggestOptions *mSuggestOptions;

View file

@ -18,14 +18,12 @@
#define LATINIME_PREV_WORDS_INFO_H #define LATINIME_PREV_WORDS_INFO_H
#include "defines.h" #include "defines.h"
#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "utils/char_utils.h" #include "utils/char_utils.h"
#include "utils/int_array_view.h" #include "utils/int_array_view.h"
namespace latinime { namespace latinime {
// TODO: Support n-gram.
class PrevWordsInfo { class PrevWordsInfo {
public: public:
// No prev word information. // No prev word information.
@ -81,11 +79,10 @@ class PrevWordsInfo {
return false; return false;
} }
void getPrevWordsTerminalPtNodePos( void getPrevWordIds(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, int *const outPrevWordIds, const bool tryLowerCaseSearch) const {
int *const outPrevWordsTerminalPtNodePos, const bool tryLowerCaseSearch) const {
for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
outPrevWordsTerminalPtNodePos[i] = getTerminalPtNodePosOfWord(dictStructurePolicy, outPrevWordIds[i] = getWordId(dictStructurePolicy,
mPrevWordCodePoints[i], mPrevWordCodePointCount[i], mPrevWordCodePoints[i], mPrevWordCodePointCount[i],
mIsBeginningOfSentence[i], tryLowerCaseSearch); mIsBeginningOfSentence[i], tryLowerCaseSearch);
} }
@ -110,12 +107,11 @@ class PrevWordsInfo {
private: private:
DISALLOW_COPY_AND_ASSIGN(PrevWordsInfo); DISALLOW_COPY_AND_ASSIGN(PrevWordsInfo);
static int getTerminalPtNodePosOfWord( static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
const int *const wordCodePoints, const int wordCodePointCount, const int *const wordCodePoints, const int wordCodePointCount,
const bool isBeginningOfSentence, const bool tryLowerCaseSearch) { const bool isBeginningOfSentence, const bool tryLowerCaseSearch) {
if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) {
return NOT_A_DICT_POS; return NOT_A_WORD_ID;
} }
int codePoints[MAX_WORD_LENGTH]; int codePoints[MAX_WORD_LENGTH];
int codePointCount = wordCodePointCount; int codePointCount = wordCodePointCount;
@ -124,21 +120,19 @@ class PrevWordsInfo {
codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints,
codePointCount, MAX_WORD_LENGTH); codePointCount, MAX_WORD_LENGTH);
if (codePointCount <= 0) { if (codePointCount <= 0) {
return NOT_A_DICT_POS; return NOT_A_WORD_ID;
} }
} }
const CodePointArrayView codePointArrayView(codePoints, codePointCount); const CodePointArrayView codePointArrayView(codePoints, codePointCount);
const int wordPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord( const int wordId = dictStructurePolicy->getWordId(
codePointArrayView, false /* forceLowerCaseSearch */); codePointArrayView, false /* forceLowerCaseSearch */);
if (wordPtNodePos != NOT_A_DICT_POS || !tryLowerCaseSearch) { if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) {
// Return the position when when the word was found or doesn't try lower case // Return the id when when the word was found or doesn't try lower case search.
// search. return wordId;
return wordPtNodePos;
} }
// Check bigrams for lower-cased previous word if original was not found. Useful for // Check bigrams for lower-cased previous word if original was not found. Useful for
// auto-capitalized words like "The [current_word]". // auto-capitalized words like "The [current_word]".
return dictStructurePolicy->getTerminalPtNodePositionOfWord( return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */);
codePointArrayView, true /* forceLowerCaseSearch */);
} }
void clear() { void clear() {

View file

@ -92,7 +92,7 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession) const {
// Create a new dic node here // Create a new dic node here
DicNode rootNode; DicNode rootNode;
DicNodeUtils::initAsRoot(traverseSession->getDictionaryStructurePolicy(), DicNodeUtils::initAsRoot(traverseSession->getDictionaryStructurePolicy(),
traverseSession->getPrevWordsPtNodePos(), &rootNode); traverseSession->getPrevWordIds(), &rootNode);
traverseSession->getDicTraverseCache()->copyPushActive(&rootNode); traverseSession->getDicTraverseCache()->copyPushActive(&rootNode);
} }
} }

View file

@ -104,7 +104,7 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
return codePointCount; return codePointCount;
} }
int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const { const bool forceLowerCaseSearch) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition()); readingHelper.initWithPtNodeArrayPos(getRootPosition());
@ -112,9 +112,9 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArray
wordCodePoints.size(), forceLowerCaseSearch); wordCodePoints.size(), forceLowerCaseSearch);
if (readingHelper.isError()) { if (readingHelper.isError()) {
mIsCorrupted = true; mIsCorrupted = true;
AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); AKLOGE("Dictionary reading error in getWordId().");
} }
return ptNodePos; return getWordIdFromTerminalPtNodePos(ptNodePos);
} }
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
@ -133,17 +133,19 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
} }
} }
int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
const int ptNodePos) const { const int wordId) const {
if (ptNodePos == NOT_A_DICT_POS) { if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
if (prevWordsPtNodePos) { if (prevWordIds) {
const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
while (bigramsIt.hasNext()) { while (bigramsIt.hasNext()) {
bigramsIt.next(); bigramsIt.next();
@ -157,16 +159,18 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtN
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
} }
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
NgramListener *const listener) const { NgramListener *const listener) const {
if (!prevWordsPtNodePos) { if (!prevWordIds) {
return; return;
} }
const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
while (bigramsIt.hasNext()) { while (bigramsIt.hasNext()) {
bigramsIt.next(); bigramsIt.next();
listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); listener->onVisitEntry(bigramsIt.getProbability(),
getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()));
} }
} }
@ -238,8 +242,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
} }
if (unigramProperty->getShortcuts().size() > 0) { if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target. // Add shortcut target.
const int wordPos = getTerminalPtNodePositionOfWord(codePointArrayView, const int wordPos = getTerminalPtNodePosFromWordId(
false /* forceLowerCaseSearch */); getWordId(codePointArrayView, false /* forceLowerCaseSearch */));
if (wordPos == NOT_A_DICT_POS) { if (wordPos == NOT_A_DICT_POS) {
AKLOGE("Cannot find terminal PtNode position to add shortcut target."); AKLOGE("Cannot find terminal PtNode position to add shortcut target.");
return false; return false;
@ -266,8 +270,8 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary.");
return false; return false;
} }
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, const int ptNodePos = getTerminalPtNodePosFromWordId(
false /* forceLowerCaseSearch */); getWordId(wordCodePoints, false /* forceLowerCaseSearch */));
if (ptNodePos == NOT_A_DICT_POS) { if (ptNodePos == NOT_A_DICT_POS) {
return false; return false;
} }
@ -295,11 +299,9 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
"length: %zd", bigramProperty->getTargetCodePoints()->size()); "length: %zd", bigramProperty->getTargetCodePoints()->size());
return false; return false;
} }
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
false /* tryLowerCaseSearch */); if (prevWordIds[0] == NOT_A_WORD_ID) {
// TODO: Support N-gram.
if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) {
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
const std::vector<UnigramProperty::ShortcutProperty> shortcuts; const std::vector<UnigramProperty::ShortcutProperty> shortcuts;
const UnigramProperty beginningOfSentenceUnigramProperty( const UnigramProperty beginningOfSentenceUnigramProperty(
@ -311,22 +313,22 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
return false; return false;
} }
// Refresh Terminal PtNode positions. // Refresh word ids.
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
false /* tryLowerCaseSearch */);
} else { } else {
return false; return false;
} }
} }
const int word1Pos = getTerminalPtNodePositionOfWord( const int wordPos = getTerminalPtNodePosFromWordId(getWordId(
CodePointArrayView(*bigramProperty->getTargetCodePoints()), CodePointArrayView(*bigramProperty->getTargetCodePoints()),
false /* forceLowerCaseSearch */); false /* forceLowerCaseSearch */));
if (word1Pos == NOT_A_DICT_POS) { if (wordPos == NOT_A_DICT_POS) {
return false; return false;
} }
bool addedNewBigram = false; bool addedNewBigram = false;
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(prevWordsPtNodePos), const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
word1Pos, bigramProperty, &addedNewBigram)) { if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos),
wordPos, bigramProperty, &addedNewBigram)) {
if (addedNewBigram) { if (addedNewBigram) {
mBigramCount++; mBigramCount++;
} }
@ -355,20 +357,19 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd", AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
wordCodePoints.size()); wordCodePoints.size());
} }
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */);
false /* tryLowerCaseSerch */); if (prevWordIds[0] == NOT_A_WORD_ID) {
// TODO: Support N-gram.
if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) {
return false; return false;
} }
const int wordPos = getTerminalPtNodePositionOfWord(wordCodePoints, const int wordPos = getTerminalPtNodePosFromWordId(getWordId(wordCodePoints,
false /* forceLowerCaseSearch */); false /* forceLowerCaseSearch */));
if (wordPos == NOT_A_DICT_POS) { if (wordPos == NOT_A_DICT_POS) {
return false; return false;
} }
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry( if (mUpdatingHelper.removeNgramEntry(
PtNodePosArrayView::fromObject(prevWordsPtNodePos), wordPos)) { PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) {
mBigramCount--; mBigramCount--;
return true; return true;
} else { } else {
@ -449,8 +450,8 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
const WordProperty Ver4PatriciaTriePolicy::getWordProperty( const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
const CodePointArrayView wordCodePoints) const { const CodePointArrayView wordCodePoints) const {
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, const int ptNodePos = getTerminalPtNodePosFromWordId(
false /* forceLowerCaseSearch */); getWordId(wordCodePoints, false /* forceLowerCaseSearch */));
if (ptNodePos == NOT_A_DICT_POS) { if (ptNodePos == NOT_A_DICT_POS) {
AKLOGE("getWordProperty is called for invalid word."); AKLOGE("getWordProperty is called for invalid word.");
return WordProperty(); return WordProperty();
@ -553,6 +554,14 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
return nextToken; return nextToken;
} }
int Ver4PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const {
return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos;
}
int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const {
return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId;
}
} // namespace v402 } // namespace v402
} // namespace backward } // namespace backward
} // namespace latinime } // namespace latinime

View file

@ -87,15 +87,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const; int *const outUnigramProbability) const;
int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
const bool forceLowerCaseSearch) const;
int getProbability(const int unigramProbability, const int bigramProbability) const; int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
void iterateNgramEntries(const int *const prevWordsPtNodePos, void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
NgramListener *const listener) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const; int getShortcutPositionOfPtNode(const int ptNodePos) const;
@ -164,6 +162,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
mutable bool mIsCorrupted; mutable bool mIsCorrupted;
int getBigramsPositionOfPtNode(const int ptNodePos) const; int getBigramsPositionOfPtNode(const int ptNodePos) const;
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
int getTerminalPtNodePosFromWordId(const int wordId) const;
}; };
} // namespace v402 } // namespace v402
} // namespace backward } // namespace backward

View file

@ -267,8 +267,8 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
} }
// This function gets the position of the terminal PtNode of the exact matching word in the // This function gets the position of the terminal PtNode of the exact matching word in the
// dictionary. If no match is found, it returns NOT_A_DICT_POS. // dictionary. If no match is found, it returns NOT_A_WORD_ID.
int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const { const bool forceLowerCaseSearch) const {
DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader); DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition()); readingHelper.initWithPtNodeArrayPos(getRootPosition());
@ -276,9 +276,9 @@ int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView
wordCodePoints.size(), forceLowerCaseSearch); wordCodePoints.size(), forceLowerCaseSearch);
if (readingHelper.isError()) { if (readingHelper.isError()) {
mIsCorrupted = true; mIsCorrupted = true;
AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); AKLOGE("Dictionary reading error in getWordId().");
} }
return ptNodePos; return getWordIdFromTerminalPtNodePos(ptNodePos);
} }
int PatriciaTriePolicy::getProbability(const int unigramProbability, int PatriciaTriePolicy::getProbability(const int unigramProbability,
@ -297,11 +297,11 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability,
} }
} }
int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const int wordId) const {
const int ptNodePos) const { if (wordId == NOT_A_WORD_ID) {
if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams = const PtNodeParams ptNodeParams =
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) { if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) {
@ -310,8 +310,9 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP
// for shortcuts). // for shortcuts).
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
if (prevWordsPtNodePos) { if (prevWordIds) {
const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition);
while (bigramsIt.hasNext()) { while (bigramsIt.hasNext()) {
bigramsIt.next(); bigramsIt.next();
@ -325,16 +326,18 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
} }
void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
NgramListener *const listener) const { NgramListener *const listener) const {
if (!prevWordsPtNodePos) { if (!prevWordIds) {
return; return;
} }
const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition);
while (bigramsIt.hasNext()) { while (bigramsIt.hasNext()) {
bigramsIt.next(); bigramsIt.next();
listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); listener->onVisitEntry(bigramsIt.getProbability(),
getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()));
} }
} }
@ -379,12 +382,12 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod
const WordProperty PatriciaTriePolicy::getWordProperty( const WordProperty PatriciaTriePolicy::getWordProperty(
const CodePointArrayView wordCodePoints) const { const CodePointArrayView wordCodePoints) const {
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) {
if (ptNodePos == NOT_A_DICT_POS) {
AKLOGE("getWordProperty was called for invalid word."); AKLOGE("getWordProperty was called for invalid word.");
return WordProperty(); return WordProperty();
} }
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams = const PtNodeParams ptNodeParams =
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
std::vector<int> codePointVector(ptNodeParams.getCodePoints(), std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
@ -467,4 +470,11 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC
return nextToken; return nextToken;
} }
int PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const {
return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos;
}
int PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const {
return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId;
}
} // namespace latinime } // namespace latinime

View file

@ -64,15 +64,13 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints, const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const; int *const outUnigramProbability) const;
int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
const bool forceLowerCaseSearch) const;
int getProbability(const int unigramProbability, const int bigramProbability) const; int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
void iterateNgramEntries(const int *const prevWordsPtNodePos, void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
NgramListener *const listener) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const; int getShortcutPositionOfPtNode(const int ptNodePos) const;
@ -163,6 +161,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getBigramsPositionOfPtNode(const int ptNodePos) const; int getBigramsPositionOfPtNode(const int ptNodePos) const;
int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos, int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos,
DicNodeVector *const childDicNodes) const; DicNodeVector *const childDicNodes) const;
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
int getTerminalPtNodePosFromWordId(const int wordId) const;
}; };
} // namespace latinime } // namespace latinime
#endif // LATINIME_PATRICIA_TRIE_POLICY_H #endif // LATINIME_PATRICIA_TRIE_POLICY_H

View file

@ -94,7 +94,7 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
return codePointCount; return codePointCount;
} }
int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const { const bool forceLowerCaseSearch) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition()); readingHelper.initWithPtNodeArrayPos(getRootPosition());
@ -104,7 +104,11 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArray
mIsCorrupted = true; mIsCorrupted = true;
AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
} }
return ptNodePos; if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_WORD_ID;
}
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
return ptNodeParams.getTerminalId();
} }
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
@ -123,24 +127,22 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
} }
} }
int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
const int ptNodePos) const { const int wordId) const {
if (ptNodePos == NOT_A_DICT_POS) { if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
if (prevWordsPtNodePos) { if (prevWordIds) {
// TODO: Support n-gram. // TODO: Support n-gram.
const PtNodeParams prevWordPtNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordsPtNodePos[0]);
const int prevWordTerminalId = prevWordPtNodeParams.getTerminalId();
const ProbabilityEntry probabilityEntry = const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry( mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
IntArrayView::fromObject(&prevWordTerminalId), IntArrayView::fromObject(prevWordIds), wordId);
ptNodeParams.getTerminalId());
if (!probabilityEntry.isValid()) { if (!probabilityEntry.isValid()) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
@ -154,26 +156,21 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtN
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
} }
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
NgramListener *const listener) const { NgramListener *const listener) const {
if (!prevWordsPtNodePos) { if (!prevWordIds) {
return; return;
} }
// TODO: Support n-gram. // TODO: Support n-gram.
const PtNodeParams ptNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordsPtNodePos[0]);
const int prevWordId = ptNodeParams.getTerminalId();
const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&prevWordId);
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent(); const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
for (const auto entry : languageModelDictContent->getProbabilityEntries(prevWordIds)) { for (const auto entry : languageModelDictContent->getProbabilityEntries(
WordIdArrayView::fromObject(prevWordIds))) {
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
const int probability = probabilityEntry.hasHistoricalInfo() ? const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability( ForgettingCurveUtils::decodeProbability(
probabilityEntry.getHistoricalInfo(), mHeaderPolicy) : probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
probabilityEntry.getProbability(); probabilityEntry.getProbability();
const int ptNodePos = mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition( listener->onVisitEntry(probability, entry.getWordId());
entry.getWordId());
listener->onVisitEntry(probability, ptNodePos);
} }
} }
@ -233,12 +230,13 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
} }
if (unigramProperty->getShortcuts().size() > 0) { if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target. // Add shortcut target.
const int wordPos = getTerminalPtNodePositionOfWord(codePointArrayView, const int wordId = getWordId(codePointArrayView, false /* forceLowerCaseSearch */);
false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) {
if (wordPos == NOT_A_DICT_POS) { AKLOGE("Cannot find word id to add shortcut target.");
AKLOGE("Cannot find terminal PtNode position to add shortcut target.");
return false; return false;
} }
const int wordPos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
for (const auto &shortcut : unigramProperty->getShortcuts()) { for (const auto &shortcut : unigramProperty->getShortcuts()) {
if (!mUpdatingHelper.addShortcutTarget(wordPos, if (!mUpdatingHelper.addShortcutTarget(wordPos,
shortcut.getTargetCodePoints()->data(), shortcut.getTargetCodePoints()->data(),
@ -261,20 +259,19 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary.");
return false; return false;
} }
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) {
if (ptNodePos == NOT_A_DICT_POS) {
return false; return false;
} }
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
if (!mNodeWriter.markPtNodeAsDeleted(&ptNodeParams)) { if (!mNodeWriter.markPtNodeAsDeleted(&ptNodeParams)) {
AKLOGE("Cannot remove unigram. ptNodePos: %d", ptNodePos); AKLOGE("Cannot remove unigram. ptNodePos: %d", ptNodePos);
return false; return false;
} }
if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry( if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry(wordId)) {
ptNodeParams.getTerminalId())) { return false;
// TODO: Uncomment.
// return false;
} }
if (!ptNodeParams.representsNonWordInfo()) { if (!ptNodeParams.representsNonWordInfo()) {
mUnigramCount--; mUnigramCount--;
@ -302,12 +299,10 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
"length: %zd", bigramProperty->getTargetCodePoints()->size()); "length: %zd", bigramProperty->getTargetCodePoints()->size());
return false; return false;
} }
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
false /* tryLowerCaseSearch */);
const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos);
// TODO: Support N-gram. // TODO: Support N-gram.
if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { if (prevWordIds[0] == NOT_A_WORD_ID) {
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
const std::vector<UnigramProperty::ShortcutProperty> shortcuts; const std::vector<UnigramProperty::ShortcutProperty> shortcuts;
const UnigramProperty beginningOfSentenceUnigramProperty( const UnigramProperty beginningOfSentenceUnigramProperty(
@ -319,22 +314,27 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
return false; return false;
} }
// Refresh Terminal PtNode positions. // Refresh word ids.
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
false /* tryLowerCaseSearch */);
} else { } else {
return false; return false;
} }
} }
const int word1Pos = getTerminalPtNodePositionOfWord( const int wordId = getWordId(CodePointArrayView(*bigramProperty->getTargetCodePoints()),
CodePointArrayView(*bigramProperty->getTargetCodePoints()),
false /* forceLowerCaseSearch */); false /* forceLowerCaseSearch */);
if (word1Pos == NOT_A_DICT_POS) { if (wordId == NOT_A_WORD_ID) {
return false; return false;
} }
bool addedNewEntry = false; bool addedNewEntry = false;
if (mUpdatingHelper.addNgramEntry(prevWordsPtNodePosView, word1Pos, bigramProperty, int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
&addedNewEntry)) { for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(prevWordIds[i]);
}
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(wordId);
if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
wordPtNodePos, bigramProperty, &addedNewEntry)) {
if (addedNewEntry) { if (addedNewEntry) {
mBigramCount++; mBigramCount++;
} }
@ -363,20 +363,25 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd", AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
wordCodePoints.size()); wordCodePoints.size());
} }
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */);
false /* tryLowerCaseSerch */);
const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos);
// TODO: Support N-gram. // TODO: Support N-gram.
if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { if (prevWordIds[0] == NOT_A_WORD_ID) {
return false; return false;
} }
const int wordPos = getTerminalPtNodePositionOfWord(wordCodePoints, const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) {
if (wordPos == NOT_A_DICT_POS) {
return false; return false;
} }
if (mUpdatingHelper.removeNgramEntry(prevWordsPtNodePosView, wordPos)) { int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(prevWordIds[i]);
}
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(wordId);
if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
wordPtNodePos)) {
mBigramCount--; mBigramCount--;
return true; return true;
} else { } else {
@ -457,12 +462,13 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
const WordProperty Ver4PatriciaTriePolicy::getWordProperty( const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
const CodePointArrayView wordCodePoints) const { const CodePointArrayView wordCodePoints) const {
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints, const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) {
if (ptNodePos == NOT_A_DICT_POS) {
AKLOGE("getWordProperty is called for invalid word."); AKLOGE("getWordProperty is called for invalid word.");
return WordProperty(); return WordProperty();
} }
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
std::vector<int> codePointVector(ptNodeParams.getCodePoints(), std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
@ -473,7 +479,6 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
// Fetch bigram information. // Fetch bigram information.
// TODO: Support n-gram. // TODO: Support n-gram.
std::vector<BigramProperty> bigrams; std::vector<BigramProperty> bigrams;
const int wordId = ptNodeParams.getTerminalId();
const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId); const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId);
const TerminalPositionLookupTable *const terminalPositionLookupTable = const TerminalPositionLookupTable *const terminalPositionLookupTable =
mBuffers->getTerminalPositionLookupTable(); mBuffers->getTerminalPositionLookupTable();

View file

@ -66,15 +66,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const; int *const outUnigramProbability) const;
int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints, int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
const bool forceLowerCaseSearch) const;
int getProbability(const int unigramProbability, const int bigramProbability) const; int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
void iterateNgramEntries(const int *const prevWordsPtNodePos, void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
NgramListener *const listener) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const; int getShortcutPositionOfPtNode(const int ptNodePos) const;