Merge "Use word id for methods related to n-grams."

main
Keisuke Kuroyanagi 2014-09-03 07:42:02 +00:00 committed by Android (Google) Code Review
commit cc6081c51b
20 changed files with 252 additions and 239 deletions

View File

@ -103,10 +103,10 @@ class DicNode {
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
}
// Init for root with prevWordsPtNodePos which is used for n-gram
void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordsPtNodePos) {
// Init for root with prevWordIds which is used for n-gram
void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordIds) {
mIsCachedForNextSuggestion = false;
mDicNodeProperties.init(rootPtNodeArrayPos, prevWordsPtNodePos);
mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds);
mDicNodeState.init();
PROF_NODE_RESET(mProfiler);
}
@ -114,12 +114,12 @@ class DicNode {
// Init for root with previous word
void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
int newPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
newPrevWordsPtNodePos[0] = dicNode->mDicNodeProperties.getPtNodePos();
for (size_t i = 1; i < NELEMS(newPrevWordsPtNodePos); ++i) {
newPrevWordsPtNodePos[i] = dicNode->getPrevWordsTerminalPtNodePos()[i - 1];
int newPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId();
for (size_t i = 1; i < NELEMS(newPrevWordIds); ++i) {
newPrevWordIds[i] = dicNode->getPrevWordIds()[i - 1];
}
mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordsPtNodePos);
mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordIds);
mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
dicNode->mDicNodeProperties.getDepth());
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
@ -145,7 +145,7 @@ class DicNode {
dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount);
mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0],
probability, wordId, hasChildren, isBlacklistedOrNotAWord, newDepth,
newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordsTerminalPtNodePos());
newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds());
mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount,
mergedNodeCodePoints);
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
@ -204,13 +204,18 @@ class DicNode {
}
// Used to get n-gram probability in DicNodeUtils.
int getWordId() const {
return mDicNodeProperties.getWordId();
}
// TODO: Remove
int getPtNodePos() const {
return mDicNodeProperties.getPtNodePos();
}
// TODO: Use view class to return PtNodePos array.
const int *getPrevWordsTerminalPtNodePos() const {
return mDicNodeProperties.getPrevWordsTerminalPtNodePos();
// TODO: Use view class to return word id array.
const int *getPrevWordIds() const {
return mDicNodeProperties.getPrevWordIds();
}
// Used in DicNodeUtils

View File

@ -29,8 +29,8 @@ namespace latinime {
/* static */ void DicNodeUtils::initAsRoot(
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
const int *const prevWordsPtNodePos, DicNode *const newRootDicNode) {
newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordsPtNodePos);
const int *const prevWordIds, DicNode *const newRootDicNode) {
newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds);
}
/*static */ void DicNodeUtils::initAsRootWithPreviousWord(
@ -86,9 +86,9 @@ namespace latinime {
const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) {
const int unigramProbability = dicNode->getProbability();
if (multiBigramMap) {
const int *const prevWordsPtNodePos = dicNode->getPrevWordsTerminalPtNodePos();
const int *const prevWordIds = dicNode->getPrevWordIds();
return multiBigramMap->getBigramProbability(dictionaryStructurePolicy,
prevWordsPtNodePos, dicNode->getPtNodePos(), unigramProbability);
prevWordIds, dicNode->getWordId(), unigramProbability);
}
return dictionaryStructurePolicy->getProbability(unigramProbability,
NOT_A_PROBABILITY);

View File

@ -30,7 +30,7 @@ class DicNodeUtils {
public:
static void initAsRoot(
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
const int *const prevWordPtNodePos, DicNode *const newRootDicNode);
const int *const prevWordIds, DicNode *const newRootDicNode);
static void initAsRootWithPreviousWord(
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode);

View File

@ -39,7 +39,7 @@ class DicNodeProperties {
// Should be called only once per DicNode is initialized.
void init(const int pos, const int childrenPos, const int nodeCodePoint, const int probability,
const int wordId, const bool hasChildren, const bool isBlacklistedOrNotAWord,
const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordsNodePos) {
const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordIds) {
mPtNodePos = pos;
mChildrenPtNodeArrayPos = childrenPos;
mDicNodeCodePoint = nodeCodePoint;
@ -49,11 +49,11 @@ class DicNodeProperties {
mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord;
mDepth = depth;
mLeavingDepth = leavingDepth;
memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos));
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
}
// Init for root with prevWordsPtNodePos which is used for n-gram
void init(const int rootPtNodeArrayPos, const int *const prevWordsNodePos) {
void init(const int rootPtNodeArrayPos, const int *const prevWordIds) {
mPtNodePos = NOT_A_DICT_POS;
mChildrenPtNodeArrayPos = rootPtNodeArrayPos;
mDicNodeCodePoint = NOT_A_CODE_POINT;
@ -63,7 +63,7 @@ class DicNodeProperties {
mIsBlacklistedOrNotAWord = false;
mDepth = 0;
mLeavingDepth = 0;
memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos));
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
}
void initByCopy(const DicNodeProperties *const dicNodeProp) {
@ -76,8 +76,7 @@ class DicNodeProperties {
mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord;
mDepth = dicNodeProp->mDepth;
mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos,
sizeof(mPrevWordsTerminalPtNodePos));
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
}
// Init as passing child
@ -91,8 +90,7 @@ class DicNodeProperties {
mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord;
mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child
mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos,
sizeof(mPrevWordsTerminalPtNodePos));
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
}
int getPtNodePos() const {
@ -132,8 +130,12 @@ class DicNodeProperties {
return mIsBlacklistedOrNotAWord;
}
const int *getPrevWordsTerminalPtNodePos() const {
return mPrevWordsTerminalPtNodePos;
const int *getPrevWordIds() const {
return mPrevWordIds;
}
int getWordId() const {
return mWordId;
}
private:
@ -149,7 +151,7 @@ class DicNodeProperties {
bool mIsBlacklistedOrNotAWord;
uint16_t mDepth;
uint16_t mLeavingDepth;
int mPrevWordsTerminalPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_PROPERTIES_H

View File

@ -93,11 +93,10 @@ void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
TimeKeeper::setCurrentTime();
NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults,
mDictionaryStructureWithBufferPolicy.get());
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(
mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos,
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
true /* tryLowerCaseSearch */);
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordsPtNodePos, &listener);
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener);
}
int Dictionary::getProbability(const int *word, int length) const {
@ -113,18 +112,17 @@ int Dictionary::getMaxProbabilityOfExactMatches(const int *word, int length) con
int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word,
int length) const {
TimeKeeper::setCurrentTime();
int nextWordPos = mDictionaryStructureWithBufferPolicy->getTerminalPtNodePositionOfWord(
int wordId = mDictionaryStructureWithBufferPolicy->getWordId(
CodePointArrayView(word, length), false /* forceLowerCaseSearch */);
if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY;
if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY;
if (!prevWordsInfo) {
return getDictionaryStructurePolicy()->getProbabilityOfPtNode(
nullptr /* prevWordsPtNodePos */, nextWordPos);
return getDictionaryStructurePolicy()->getProbabilityOfWord(
nullptr /* prevWordsPtNodePos */, wordId);
}
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(
mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos,
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
true /* tryLowerCaseSearch */);
return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsPtNodePos, nextWordPos);
return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId);
}
bool Dictionary::addUnigramEntry(const int *const word, const int length,

View File

@ -34,11 +34,11 @@ namespace latinime {
// No prev words information.
PrevWordsInfo emptyPrevWordsInfo;
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
emptyPrevWordsInfo.getPrevWordsTerminalPtNodePos(dictionaryStructurePolicy,
prevWordsPtNodePos, false /* tryLowerCaseSearch */);
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds,
false /* tryLowerCaseSearch */);
current.emplace_back();
DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordsPtNodePos, &current.front());
DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, &current.front());
for (int i = 0; i < codePointCount; ++i) {
// The base-lower input is used to ignore case errors and accent errors.
const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]);

View File

@ -35,39 +35,37 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP =
// Also caches the bigrams if there is space remaining and they have not been cached already.
int MultiBigramMap::getBigramProbability(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos, const int nextWordPosition,
const int *const prevWordIds, const int nextWordId,
const int unigramProbability) {
if (!prevWordsPtNodePos || prevWordsPtNodePos[0] == NOT_A_DICT_POS) {
if (!prevWordIds || prevWordIds[0] == NOT_A_WORD_ID) {
return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY);
}
std::unordered_map<int, BigramMap>::const_iterator mapPosition =
mBigramMaps.find(prevWordsPtNodePos[0]);
const auto mapPosition = mBigramMaps.find(prevWordIds[0]);
if (mapPosition != mBigramMaps.end()) {
return mapPosition->second.getBigramProbability(structurePolicy, nextWordPosition,
return mapPosition->second.getBigramProbability(structurePolicy, nextWordId,
unigramProbability);
}
if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) {
addBigramsForWordPosition(structurePolicy, prevWordsPtNodePos);
return mBigramMaps[prevWordsPtNodePos[0]].getBigramProbability(structurePolicy,
nextWordPosition, unigramProbability);
addBigramsForWord(structurePolicy, prevWordIds);
return mBigramMaps[prevWordIds[0]].getBigramProbability(structurePolicy,
nextWordId, unigramProbability);
}
return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordsPtNodePos,
nextWordPosition, unigramProbability);
return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordIds,
nextWordId, unigramProbability);
}
void MultiBigramMap::BigramMap::init(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos) {
structurePolicy->iterateNgramEntries(prevWordsPtNodePos, this /* listener */);
const int *const prevWordIds) {
structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */);
}
int MultiBigramMap::BigramMap::getBigramProbability(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int nextWordPosition, const int unigramProbability) const {
const int nextWordId, const int unigramProbability) const {
int bigramProbability = NOT_A_PROBABILITY;
if (mBloomFilter.isInFilter(nextWordPosition)) {
const std::unordered_map<int, int>::const_iterator bigramProbabilityIt =
mBigramMap.find(nextWordPosition);
if (mBloomFilter.isInFilter(nextWordId)) {
const auto bigramProbabilityIt = mBigramMap.find(nextWordId);
if (bigramProbabilityIt != mBigramMap.end()) {
bigramProbability = bigramProbabilityIt->second;
}
@ -75,29 +73,27 @@ int MultiBigramMap::BigramMap::getBigramProbability(
return structurePolicy->getProbability(unigramProbability, bigramProbability);
}
void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability,
const int targetPtNodePos) {
if (targetPtNodePos == NOT_A_DICT_POS) {
void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, const int targetWordId) {
if (targetWordId == NOT_A_WORD_ID) {
return;
}
mBigramMap[targetPtNodePos] = ngramProbability;
mBloomFilter.setInFilter(targetPtNodePos);
mBigramMap[targetWordId] = ngramProbability;
mBloomFilter.setInFilter(targetWordId);
}
void MultiBigramMap::addBigramsForWordPosition(
void MultiBigramMap::addBigramsForWord(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos) {
if (prevWordsPtNodePos) {
mBigramMaps[prevWordsPtNodePos[0]].init(structurePolicy, prevWordsPtNodePos);
const int *const prevWordIds) {
if (prevWordIds) {
mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds);
}
}
int MultiBigramMap::readBigramProbabilityFromBinaryDictionary(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos, const int nextWordPosition,
const int *const prevWordIds, const int nextWordId,
const int unigramProbability) {
const int bigramProbability = structurePolicy->getProbabilityOfPtNode(prevWordsPtNodePos,
nextWordPosition);
const int bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId);
if (bigramProbability != NOT_A_PROBABILITY) {
return bigramProbability;
}

View File

@ -39,8 +39,7 @@ class MultiBigramMap {
// Look up the bigram probability for the given word pair from the cached bigram maps.
// Also caches the bigrams if there is space remaining and they have not been cached already.
int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos, const int nextWordPosition,
const int unigramProbability);
const int *const prevWordIds, const int nextWordId, const int unigramProbability);
void clear() {
mBigramMaps.clear();
@ -58,11 +57,11 @@ class MultiBigramMap {
virtual ~BigramMap() {}
void init(const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos);
const int *const prevWordIds);
int getBigramProbability(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int nextWordPosition, const int unigramProbability) const;
virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos);
const int nextWordId, const int unigramProbability) const;
virtual void onVisitEntry(const int ngramProbability, const int targetWordId);
private:
static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP;
@ -70,14 +69,12 @@ class MultiBigramMap {
BloomFilter mBloomFilter;
};
void addBigramsForWordPosition(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos);
void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordIds);
int readBigramProbabilityFromBinaryDictionary(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordsPtNodePos, const int nextWordPosition,
const int unigramProbability);
const int *const prevWordIds, const int nextWordId, const int unigramProbability);
static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP;
std::unordered_map<int, BigramMap> mBigramMaps;

View File

@ -26,7 +26,7 @@ namespace latinime {
*/
class NgramListener {
public:
virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos) = 0;
virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0;
virtual ~NgramListener() {};
protected:

View File

@ -53,15 +53,14 @@ class DictionaryStructureWithBufferPolicy {
const int ptNodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const = 0;
virtual int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints,
virtual int getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const = 0;
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
virtual int getProbabilityOfPtNode(const int *const prevWordsPtNodePos,
const int ptNodePos) const = 0;
virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0;
virtual void iterateNgramEntries(const int *const prevWordsPtNodePos,
virtual void iterateNgramEntries(const int *const prevWordIds,
NgramListener *const listener) const = 0;
virtual int getShortcutPositionOfPtNode(const int ptNodePos) const = 0;

View File

@ -35,8 +35,8 @@ void DicTraverseSession::init(const Dictionary *const dictionary,
mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy()
->getMultiWordCostMultiplier();
mSuggestOptions = suggestOptions;
prevWordsInfo->getPrevWordsTerminalPtNodePos(
getDictionaryStructurePolicy(), mPrevWordsPtNodePos, true /* tryLowerCaseSearch */);
prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds,
true /* tryLowerCaseSearch */);
}
void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo,

View File

@ -55,8 +55,8 @@ class DicTraverseSession {
mMultiWordCostMultiplier(1.0f) {
// NOTE: mProximityInfoStates is an array of instances.
// No need to initialize it explicitly here.
for (size_t i = 0; i < NELEMS(mPrevWordsPtNodePos); ++i) {
mPrevWordsPtNodePos[i] = NOT_A_DICT_POS;
for (size_t i = 0; i < NELEMS(mPrevWordsIds); ++i) {
mPrevWordsIds[i] = NOT_A_DICT_POS;
}
}
@ -79,7 +79,7 @@ class DicTraverseSession {
//--------------------
const ProximityInfo *getProximityInfo() const { return mProximityInfo; }
const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; }
const int *getPrevWordsPtNodePos() const { return mPrevWordsPtNodePos; }
const int *getPrevWordIds() const { return mPrevWordsIds; }
DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; }
const ProximityInfoState *getProximityInfoState(int id) const {
@ -166,7 +166,7 @@ class DicTraverseSession {
const int *const inputYs, const int *const times, const int *const pointerIds,
const int inputSize, const float maxSpatialDistance, const int maxPointerCount);
int mPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
int mPrevWordsIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
const ProximityInfo *mProximityInfo;
const Dictionary *mDictionary;
const SuggestOptions *mSuggestOptions;

View File

@ -18,14 +18,12 @@
#define LATINIME_PREV_WORDS_INFO_H
#include "defines.h"
#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "utils/char_utils.h"
#include "utils/int_array_view.h"
namespace latinime {
// TODO: Support n-gram.
class PrevWordsInfo {
public:
// No prev word information.
@ -81,11 +79,10 @@ class PrevWordsInfo {
return false;
}
void getPrevWordsTerminalPtNodePos(
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
int *const outPrevWordsTerminalPtNodePos, const bool tryLowerCaseSearch) const {
void getPrevWordIds(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
int *const outPrevWordIds, const bool tryLowerCaseSearch) const {
for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
outPrevWordsTerminalPtNodePos[i] = getTerminalPtNodePosOfWord(dictStructurePolicy,
outPrevWordIds[i] = getWordId(dictStructurePolicy,
mPrevWordCodePoints[i], mPrevWordCodePointCount[i],
mIsBeginningOfSentence[i], tryLowerCaseSearch);
}
@ -110,12 +107,11 @@ class PrevWordsInfo {
private:
DISALLOW_COPY_AND_ASSIGN(PrevWordsInfo);
static int getTerminalPtNodePosOfWord(
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
const int *const wordCodePoints, const int wordCodePointCount,
const bool isBeginningOfSentence, const bool tryLowerCaseSearch) {
if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) {
return NOT_A_DICT_POS;
return NOT_A_WORD_ID;
}
int codePoints[MAX_WORD_LENGTH];
int codePointCount = wordCodePointCount;
@ -124,21 +120,19 @@ class PrevWordsInfo {
codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints,
codePointCount, MAX_WORD_LENGTH);
if (codePointCount <= 0) {
return NOT_A_DICT_POS;
return NOT_A_WORD_ID;
}
}
const CodePointArrayView codePointArrayView(codePoints, codePointCount);
const int wordPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord(
const int wordId = dictStructurePolicy->getWordId(
codePointArrayView, false /* forceLowerCaseSearch */);
if (wordPtNodePos != NOT_A_DICT_POS || !tryLowerCaseSearch) {
// Return the position when when the word was found or doesn't try lower case
// search.
return wordPtNodePos;
if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) {
// Return the id when when the word was found or doesn't try lower case search.
return wordId;
}
// Check bigrams for lower-cased previous word if original was not found. Useful for
// auto-capitalized words like "The [current_word]".
return dictStructurePolicy->getTerminalPtNodePositionOfWord(
codePointArrayView, true /* forceLowerCaseSearch */);
return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */);
}
void clear() {

View File

@ -92,7 +92,7 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession) const {
// Create a new dic node here
DicNode rootNode;
DicNodeUtils::initAsRoot(traverseSession->getDictionaryStructurePolicy(),
traverseSession->getPrevWordsPtNodePos(), &rootNode);
traverseSession->getPrevWordIds(), &rootNode);
traverseSession->getDicTraverseCache()->copyPushActive(&rootNode);
}
}

View File

@ -104,7 +104,7 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
return codePointCount;
}
int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints,
int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition());
@ -112,9 +112,9 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArray
wordCodePoints.size(), forceLowerCaseSearch);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
AKLOGE("Dictionary reading error in getWordId().");
}
return ptNodePos;
return getWordIdFromTerminalPtNodePos(ptNodePos);
}
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
@ -133,17 +133,19 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
}
}
int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos,
const int ptNodePos) const {
if (ptNodePos == NOT_A_DICT_POS) {
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
const int wordId) const {
if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY;
}
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
if (prevWordsPtNodePos) {
const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]);
if (prevWordIds) {
const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
while (bigramsIt.hasNext()) {
bigramsIt.next();
@ -157,16 +159,18 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtN
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos,
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
NgramListener *const listener) const {
if (!prevWordsPtNodePos) {
if (!prevWordIds) {
return;
}
const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]);
const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
while (bigramsIt.hasNext()) {
bigramsIt.next();
listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos());
listener->onVisitEntry(bigramsIt.getProbability(),
getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()));
}
}
@ -238,8 +242,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
const int wordPos = getTerminalPtNodePositionOfWord(codePointArrayView,
false /* forceLowerCaseSearch */);
const int wordPos = getTerminalPtNodePosFromWordId(
getWordId(codePointArrayView, false /* forceLowerCaseSearch */));
if (wordPos == NOT_A_DICT_POS) {
AKLOGE("Cannot find terminal PtNode position to add shortcut target.");
return false;
@ -266,8 +270,8 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary.");
return false;
}
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints,
false /* forceLowerCaseSearch */);
const int ptNodePos = getTerminalPtNodePosFromWordId(
getWordId(wordCodePoints, false /* forceLowerCaseSearch */));
if (ptNodePos == NOT_A_DICT_POS) {
return false;
}
@ -295,11 +299,9 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
"length: %zd", bigramProperty->getTargetCodePoints()->size());
return false;
}
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos,
false /* tryLowerCaseSearch */);
// TODO: Support N-gram.
if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) {
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
if (prevWordIds[0] == NOT_A_WORD_ID) {
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
const std::vector<UnigramProperty::ShortcutProperty> shortcuts;
const UnigramProperty beginningOfSentenceUnigramProperty(
@ -311,22 +313,22 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
return false;
}
// Refresh Terminal PtNode positions.
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos,
false /* tryLowerCaseSearch */);
// Refresh word ids.
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
} else {
return false;
}
}
const int word1Pos = getTerminalPtNodePositionOfWord(
const int wordPos = getTerminalPtNodePosFromWordId(getWordId(
CodePointArrayView(*bigramProperty->getTargetCodePoints()),
false /* forceLowerCaseSearch */);
if (word1Pos == NOT_A_DICT_POS) {
false /* forceLowerCaseSearch */));
if (wordPos == NOT_A_DICT_POS) {
return false;
}
bool addedNewBigram = false;
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(prevWordsPtNodePos),
word1Pos, bigramProperty, &addedNewBigram)) {
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos),
wordPos, bigramProperty, &addedNewBigram)) {
if (addedNewBigram) {
mBigramCount++;
}
@ -355,20 +357,19 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
wordCodePoints.size());
}
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos,
false /* tryLowerCaseSerch */);
// TODO: Support N-gram.
if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) {
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */);
if (prevWordIds[0] == NOT_A_WORD_ID) {
return false;
}
const int wordPos = getTerminalPtNodePositionOfWord(wordCodePoints,
false /* forceLowerCaseSearch */);
const int wordPos = getTerminalPtNodePosFromWordId(getWordId(wordCodePoints,
false /* forceLowerCaseSearch */));
if (wordPos == NOT_A_DICT_POS) {
return false;
}
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry(
PtNodePosArrayView::fromObject(prevWordsPtNodePos), wordPos)) {
PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) {
mBigramCount--;
return true;
} else {
@ -449,8 +450,8 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
const CodePointArrayView wordCodePoints) const {
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints,
false /* forceLowerCaseSearch */);
const int ptNodePos = getTerminalPtNodePosFromWordId(
getWordId(wordCodePoints, false /* forceLowerCaseSearch */));
if (ptNodePos == NOT_A_DICT_POS) {
AKLOGE("getWordProperty is called for invalid word.");
return WordProperty();
@ -553,6 +554,14 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
return nextToken;
}
int Ver4PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const {
return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos;
}
int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const {
return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId;
}
} // namespace v402
} // namespace backward
} // namespace latinime

View File

@ -87,15 +87,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const;
int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
void iterateNgramEntries(const int *const prevWordsPtNodePos,
NgramListener *const listener) const;
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const;
@ -164,6 +162,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
mutable bool mIsCorrupted;
int getBigramsPositionOfPtNode(const int ptNodePos) const;
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
int getTerminalPtNodePosFromWordId(const int wordId) const;
};
} // namespace v402
} // namespace backward

View File

@ -267,8 +267,8 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
}
// This function gets the position of the terminal PtNode of the exact matching word in the
// dictionary. If no match is found, it returns NOT_A_DICT_POS.
int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints,
// dictionary. If no match is found, it returns NOT_A_WORD_ID.
int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const {
DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition());
@ -276,9 +276,9 @@ int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView
wordCodePoints.size(), forceLowerCaseSearch);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
AKLOGE("Dictionary reading error in getWordId().");
}
return ptNodePos;
return getWordIdFromTerminalPtNodePos(ptNodePos);
}
int PatriciaTriePolicy::getProbability(const int unigramProbability,
@ -297,11 +297,11 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability,
}
}
int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos,
const int ptNodePos) const {
if (ptNodePos == NOT_A_DICT_POS) {
int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const int wordId) const {
if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY;
}
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams =
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) {
@ -310,8 +310,9 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP
// for shortcuts).
return NOT_A_PROBABILITY;
}
if (prevWordsPtNodePos) {
const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]);
if (prevWordIds) {
const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition);
while (bigramsIt.hasNext()) {
bigramsIt.next();
@ -325,16 +326,18 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos,
void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
NgramListener *const listener) const {
if (!prevWordsPtNodePos) {
if (!prevWordIds) {
return;
}
const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]);
const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition);
while (bigramsIt.hasNext()) {
bigramsIt.next();
listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos());
listener->onVisitEntry(bigramsIt.getProbability(),
getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()));
}
}
@ -379,12 +382,12 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod
const WordProperty PatriciaTriePolicy::getWordProperty(
const CodePointArrayView wordCodePoints) const {
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints,
false /* forceLowerCaseSearch */);
if (ptNodePos == NOT_A_DICT_POS) {
const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
if (wordId == NOT_A_WORD_ID) {
AKLOGE("getWordProperty was called for invalid word.");
return WordProperty();
}
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams =
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
@ -467,4 +470,11 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC
return nextToken;
}
int PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const {
return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos;
}
int PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const {
return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId;
}
} // namespace latinime

View File

@ -64,15 +64,13 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const;
int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
void iterateNgramEntries(const int *const prevWordsPtNodePos,
NgramListener *const listener) const;
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const;
@ -163,6 +161,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getBigramsPositionOfPtNode(const int ptNodePos) const;
int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos,
DicNodeVector *const childDicNodes) const;
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
int getTerminalPtNodePosFromWordId(const int wordId) const;
};
} // namespace latinime
#endif // LATINIME_PATRICIA_TRIE_POLICY_H

View File

@ -94,7 +94,7 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
return codePointCount;
}
int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints,
int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition());
@ -104,7 +104,11 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const CodePointArray
mIsCorrupted = true;
AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
}
return ptNodePos;
if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_WORD_ID;
}
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
return ptNodeParams.getTerminalId();
}
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
@ -123,24 +127,22 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
}
}
int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos,
const int ptNodePos) const {
if (ptNodePos == NOT_A_DICT_POS) {
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
const int wordId) const {
if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY;
}
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
if (prevWordsPtNodePos) {
if (prevWordIds) {
// TODO: Support n-gram.
const PtNodeParams prevWordPtNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordsPtNodePos[0]);
const int prevWordTerminalId = prevWordPtNodeParams.getTerminalId();
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
IntArrayView::fromObject(&prevWordTerminalId),
ptNodeParams.getTerminalId());
IntArrayView::fromObject(prevWordIds), wordId);
if (!probabilityEntry.isValid()) {
return NOT_A_PROBABILITY;
}
@ -154,26 +156,21 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtN
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos,
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
NgramListener *const listener) const {
if (!prevWordsPtNodePos) {
if (!prevWordIds) {
return;
}
// TODO: Support n-gram.
const PtNodeParams ptNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordsPtNodePos[0]);
const int prevWordId = ptNodeParams.getTerminalId();
const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&prevWordId);
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
for (const auto entry : languageModelDictContent->getProbabilityEntries(prevWordIds)) {
for (const auto entry : languageModelDictContent->getProbabilityEntries(
WordIdArrayView::fromObject(prevWordIds))) {
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
probabilityEntry.getProbability();
const int ptNodePos = mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(
entry.getWordId());
listener->onVisitEntry(probability, ptNodePos);
listener->onVisitEntry(probability, entry.getWordId());
}
}
@ -233,12 +230,13 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
const int wordPos = getTerminalPtNodePositionOfWord(codePointArrayView,
false /* forceLowerCaseSearch */);
if (wordPos == NOT_A_DICT_POS) {
AKLOGE("Cannot find terminal PtNode position to add shortcut target.");
const int wordId = getWordId(codePointArrayView, false /* forceLowerCaseSearch */);
if (wordId == NOT_A_WORD_ID) {
AKLOGE("Cannot find word id to add shortcut target.");
return false;
}
const int wordPos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
for (const auto &shortcut : unigramProperty->getShortcuts()) {
if (!mUpdatingHelper.addShortcutTarget(wordPos,
shortcut.getTargetCodePoints()->data(),
@ -261,20 +259,19 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary.");
return false;
}
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints,
false /* forceLowerCaseSearch */);
if (ptNodePos == NOT_A_DICT_POS) {
const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
if (wordId == NOT_A_WORD_ID) {
return false;
}
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
if (!mNodeWriter.markPtNodeAsDeleted(&ptNodeParams)) {
AKLOGE("Cannot remove unigram. ptNodePos: %d", ptNodePos);
return false;
}
if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry(
ptNodeParams.getTerminalId())) {
// TODO: Uncomment.
// return false;
if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry(wordId)) {
return false;
}
if (!ptNodeParams.representsNonWordInfo()) {
mUnigramCount--;
@ -302,12 +299,10 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
"length: %zd", bigramProperty->getTargetCodePoints()->size());
return false;
}
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos,
false /* tryLowerCaseSearch */);
const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos);
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
// TODO: Support N-gram.
if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) {
if (prevWordIds[0] == NOT_A_WORD_ID) {
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
const std::vector<UnigramProperty::ShortcutProperty> shortcuts;
const UnigramProperty beginningOfSentenceUnigramProperty(
@ -319,22 +314,27 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
return false;
}
// Refresh Terminal PtNode positions.
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos,
false /* tryLowerCaseSearch */);
// Refresh word ids.
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
} else {
return false;
}
}
const int word1Pos = getTerminalPtNodePositionOfWord(
CodePointArrayView(*bigramProperty->getTargetCodePoints()),
const int wordId = getWordId(CodePointArrayView(*bigramProperty->getTargetCodePoints()),
false /* forceLowerCaseSearch */);
if (word1Pos == NOT_A_DICT_POS) {
if (wordId == NOT_A_WORD_ID) {
return false;
}
bool addedNewEntry = false;
if (mUpdatingHelper.addNgramEntry(prevWordsPtNodePosView, word1Pos, bigramProperty,
&addedNewEntry)) {
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(prevWordIds[i]);
}
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(wordId);
if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
wordPtNodePos, bigramProperty, &addedNewEntry)) {
if (addedNewEntry) {
mBigramCount++;
}
@ -363,20 +363,25 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
wordCodePoints.size());
}
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos,
false /* tryLowerCaseSerch */);
const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos);
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */);
// TODO: Support N-gram.
if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) {
if (prevWordIds[0] == NOT_A_WORD_ID) {
return false;
}
const int wordPos = getTerminalPtNodePositionOfWord(wordCodePoints,
false /* forceLowerCaseSearch */);
if (wordPos == NOT_A_DICT_POS) {
const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
if (wordId == NOT_A_WORD_ID) {
return false;
}
if (mUpdatingHelper.removeNgramEntry(prevWordsPtNodePosView, wordPos)) {
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(prevWordIds[i]);
}
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(wordId);
if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
wordPtNodePos)) {
mBigramCount--;
return true;
} else {
@ -457,12 +462,13 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
const CodePointArrayView wordCodePoints) const {
const int ptNodePos = getTerminalPtNodePositionOfWord(wordCodePoints,
false /* forceLowerCaseSearch */);
if (ptNodePos == NOT_A_DICT_POS) {
const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
if (wordId == NOT_A_WORD_ID) {
AKLOGE("getWordProperty is called for invalid word.");
return WordProperty();
}
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
@ -473,7 +479,6 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
// Fetch bigram information.
// TODO: Support n-gram.
std::vector<BigramProperty> bigrams;
const int wordId = ptNodeParams.getTerminalId();
const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId);
const TerminalPositionLookupTable *const terminalPositionLookupTable =
mBuffers->getTerminalPositionLookupTable();

View File

@ -66,15 +66,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const;
int getTerminalPtNodePositionOfWord(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
void iterateNgramEntries(const int *const prevWordsPtNodePos,
NgramListener *const listener) const;
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const;