Merge "Move entry updating method to language model dict content."
This commit is contained in:
commit
c2429c54ac
7 changed files with 116 additions and 73 deletions
|
@ -25,6 +25,7 @@ namespace latinime {
|
||||||
|
|
||||||
const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
|
const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
|
||||||
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
|
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
|
||||||
|
const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
|
||||||
|
|
||||||
bool LanguageModelDictContent::save(FILE *const file) const {
|
bool LanguageModelDictContent::save(FILE *const file) const {
|
||||||
return mTrieMap.save(file);
|
return mTrieMap.save(file);
|
||||||
|
@ -143,6 +144,56 @@ bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds,
|
||||||
|
const int wordId, const bool isValid, const HistoricalInfo historicalInfo,
|
||||||
|
const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount) {
|
||||||
|
if (outAddedNewNgramEntryCount) {
|
||||||
|
*outAddedNewNgramEntryCount = 0;
|
||||||
|
}
|
||||||
|
if (!mHasHistoricalInfo) {
|
||||||
|
AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info.");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const ProbabilityEntry originalUnigramProbabilityEntry = getProbabilityEntry(wordId);
|
||||||
|
const ProbabilityEntry updatedUnigramProbabilityEntry = createUpdatedEntryFrom(
|
||||||
|
originalUnigramProbabilityEntry, isValid, historicalInfo, headerPolicy);
|
||||||
|
if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < prevWordIds.size(); ++i) {
|
||||||
|
if (prevWordIds[i] == NOT_A_WORD_ID) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// TODO: Optimize this code.
|
||||||
|
const WordIdArrayView limitedPrevWordIds = prevWordIds.limit(i + 1);
|
||||||
|
const ProbabilityEntry originalNgramProbabilityEntry = getNgramProbabilityEntry(
|
||||||
|
limitedPrevWordIds, wordId);
|
||||||
|
const ProbabilityEntry updatedNgramProbabilityEntry = createUpdatedEntryFrom(
|
||||||
|
originalNgramProbabilityEntry, isValid, historicalInfo, headerPolicy);
|
||||||
|
if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!originalNgramProbabilityEntry.isValid() && outAddedNewNgramEntryCount) {
|
||||||
|
*outAddedNewNgramEntryCount += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
|
||||||
|
const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
|
||||||
|
const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
|
||||||
|
const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo(
|
||||||
|
originalProbabilityEntry.getHistoricalInfo(), isValid ?
|
||||||
|
DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY,
|
||||||
|
&historicalInfo, headerPolicy);
|
||||||
|
if (originalProbabilityEntry.isValid()) {
|
||||||
|
return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo);
|
||||||
|
} else {
|
||||||
|
return ProbabilityEntry(0 /* flags */, &updatedHistoricalInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool LanguageModelDictContent::runGCInner(
|
bool LanguageModelDictContent::runGCInner(
|
||||||
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
|
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
|
||||||
const TrieMap::TrieMapRange trieMapRange,
|
const TrieMap::TrieMapRange trieMapRange,
|
||||||
|
@ -203,7 +254,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
|
||||||
return bitmapEntryIndex;
|
return bitmapEntryIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
|
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
|
||||||
const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
|
const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
|
||||||
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
|
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
|
||||||
if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
|
if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
|
||||||
|
@ -237,7 +288,7 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
|
||||||
if (!entry.hasNextLevelMap()) {
|
if (!entry.hasNextLevelMap()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
|
if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
|
||||||
headerPolicy, outEntryCounts)) {
|
headerPolicy, outEntryCounts)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -154,19 +154,23 @@ class LanguageModelDictContent {
|
||||||
|
|
||||||
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
|
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
|
||||||
|
|
||||||
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
|
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
|
||||||
int *const outEntryCounts) {
|
int *const outEntryCounts) {
|
||||||
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
||||||
outEntryCounts[i] = 0;
|
outEntryCounts[i] = 0;
|
||||||
}
|
}
|
||||||
return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
|
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
|
||||||
headerPolicy, outEntryCounts);
|
0 /* level */, headerPolicy, outEntryCounts);
|
||||||
}
|
}
|
||||||
|
|
||||||
// entryCounts should be created by updateAllProbabilityEntries.
|
// entryCounts should be created by updateAllProbabilityEntries.
|
||||||
bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
|
bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
|
||||||
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
|
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
|
||||||
|
|
||||||
|
bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
|
||||||
|
const bool isValid, const HistoricalInfo historicalInfo,
|
||||||
|
const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
|
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
|
||||||
|
|
||||||
|
@ -193,6 +197,9 @@ class LanguageModelDictContent {
|
||||||
DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
|
DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO: Remove
|
||||||
|
static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;
|
||||||
|
|
||||||
TrieMap mTrieMap;
|
TrieMap mTrieMap;
|
||||||
const bool mHasHistoricalInfo;
|
const bool mHasHistoricalInfo;
|
||||||
|
|
||||||
|
@ -201,13 +208,16 @@ class LanguageModelDictContent {
|
||||||
int *const outNgramCount);
|
int *const outNgramCount);
|
||||||
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
|
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
|
||||||
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
|
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
|
||||||
bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
|
bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int level,
|
||||||
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
|
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
|
||||||
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
|
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
|
||||||
const int maxEntryCount, const int targetLevel, int *const outEntryCount);
|
const int maxEntryCount, const int targetLevel, int *const outEntryCount);
|
||||||
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
|
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
|
||||||
const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
|
const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
|
||||||
std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
|
std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
|
||||||
|
const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
|
||||||
|
const bool isValid, const HistoricalInfo historicalInfo,
|
||||||
|
const HeaderPolicy *const headerPolicy) const;
|
||||||
};
|
};
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
|
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
|
||||||
|
|
|
@ -142,14 +142,9 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
|
||||||
if (!toBeUpdatedPtNodeParams->isTerminal()) {
|
if (!toBeUpdatedPtNodeParams->isTerminal()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const ProbabilityEntry originalProbabilityEntry =
|
|
||||||
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
|
|
||||||
toBeUpdatedPtNodeParams->getTerminalId());
|
|
||||||
const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
|
const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
|
||||||
const ProbabilityEntry updatedProbabilityEntry =
|
|
||||||
createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
|
|
||||||
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
|
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
|
||||||
toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
|
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntryOfUnigramProperty);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
|
bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
|
||||||
|
@ -203,10 +198,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
|
||||||
// Write probability.
|
// Write probability.
|
||||||
ProbabilityEntry newProbabilityEntry;
|
ProbabilityEntry newProbabilityEntry;
|
||||||
const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
|
const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
|
||||||
const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
|
|
||||||
&newProbabilityEntry, &probabilityEntryOfUnigramProperty);
|
|
||||||
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
|
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
|
||||||
terminalId, &probabilityEntryToWrite);
|
terminalId, &probabilityEntryOfUnigramProperty);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Support counting ngram entries.
|
// TODO: Support counting ngram entries.
|
||||||
|
@ -217,10 +210,8 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
|
||||||
const ProbabilityEntry probabilityEntry =
|
const ProbabilityEntry probabilityEntry =
|
||||||
languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId);
|
languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId);
|
||||||
const ProbabilityEntry probabilityEntryOfNgramProperty(ngramProperty);
|
const ProbabilityEntry probabilityEntryOfNgramProperty(ngramProperty);
|
||||||
const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
|
|
||||||
&probabilityEntry, &probabilityEntryOfNgramProperty);
|
|
||||||
if (!languageModelDictContent->setNgramProbabilityEntry(
|
if (!languageModelDictContent->setNgramProbabilityEntry(
|
||||||
prevWordIds, wordId, &updatedProbabilityEntry)) {
|
prevWordIds, wordId, &probabilityEntryOfNgramProperty)) {
|
||||||
AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d",
|
AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d",
|
||||||
prevWordIds[0], prevWordIds.size(), wordId);
|
prevWordIds[0], prevWordIds.size(), wordId);
|
||||||
return false;
|
return false;
|
||||||
|
@ -346,22 +337,6 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
|
||||||
ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
|
ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Move probability handling code to LanguageModelDictContent.
|
|
||||||
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
|
|
||||||
const ProbabilityEntry *const originalProbabilityEntry,
|
|
||||||
const ProbabilityEntry *const probabilityEntry) const {
|
|
||||||
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
|
|
||||||
const HistoricalInfo updatedHistoricalInfo =
|
|
||||||
ForgettingCurveUtils::createUpdatedHistoricalInfo(
|
|
||||||
originalProbabilityEntry->getHistoricalInfo(),
|
|
||||||
probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
|
|
||||||
mHeaderPolicy);
|
|
||||||
return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo);
|
|
||||||
} else {
|
|
||||||
return *probabilityEntry;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal,
|
bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal,
|
||||||
const bool hasMultipleChars) {
|
const bool hasMultipleChars) {
|
||||||
// Create node flags and write them.
|
// Create node flags and write them.
|
||||||
|
|
|
@ -38,11 +38,10 @@ class Ver4ShortcutListPolicy;
|
||||||
class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
|
class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
|
||||||
public:
|
public:
|
||||||
Ver4PatriciaTrieNodeWriter(BufferWithExtendableBuffer *const trieBuffer,
|
Ver4PatriciaTrieNodeWriter(BufferWithExtendableBuffer *const trieBuffer,
|
||||||
Ver4DictBuffers *const buffers, const HeaderPolicy *const headerPolicy,
|
Ver4DictBuffers *const buffers, const PtNodeReader *const ptNodeReader,
|
||||||
const PtNodeReader *const ptNodeReader,
|
|
||||||
const PtNodeArrayReader *const ptNodeArrayReader,
|
const PtNodeArrayReader *const ptNodeArrayReader,
|
||||||
Ver4ShortcutListPolicy *const shortcutPolicy)
|
Ver4ShortcutListPolicy *const shortcutPolicy)
|
||||||
: mTrieBuffer(trieBuffer), mBuffers(buffers), mHeaderPolicy(headerPolicy),
|
: mTrieBuffer(trieBuffer), mBuffers(buffers),
|
||||||
mReadingHelper(ptNodeReader, ptNodeArrayReader), mShortcutPolicy(shortcutPolicy) {}
|
mReadingHelper(ptNodeReader, ptNodeArrayReader), mShortcutPolicy(shortcutPolicy) {}
|
||||||
|
|
||||||
virtual ~Ver4PatriciaTrieNodeWriter() {}
|
virtual ~Ver4PatriciaTrieNodeWriter() {}
|
||||||
|
@ -96,20 +95,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
|
||||||
const PtNodeParams *const ptNodeParams, int *const outTerminalId,
|
const PtNodeParams *const ptNodeParams, int *const outTerminalId,
|
||||||
int *const ptNodeWritingPos);
|
int *const ptNodeWritingPos);
|
||||||
|
|
||||||
// Create updated probability entry using given probability property. In addition to the
|
|
||||||
// probability, this method updates historical information if needed.
|
|
||||||
// TODO: Update flags.
|
|
||||||
const ProbabilityEntry createUpdatedEntryFrom(
|
|
||||||
const ProbabilityEntry *const originalProbabilityEntry,
|
|
||||||
const ProbabilityEntry *const probabilityEntry) const;
|
|
||||||
|
|
||||||
bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars);
|
bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars);
|
||||||
|
|
||||||
static const int CHILDREN_POSITION_FIELD_SIZE;
|
static const int CHILDREN_POSITION_FIELD_SIZE;
|
||||||
|
|
||||||
BufferWithExtendableBuffer *const mTrieBuffer;
|
BufferWithExtendableBuffer *const mTrieBuffer;
|
||||||
Ver4DictBuffers *const mBuffers;
|
Ver4DictBuffers *const mBuffers;
|
||||||
const HeaderPolicy *const mHeaderPolicy;
|
|
||||||
DynamicPtReadingHelper mReadingHelper;
|
DynamicPtReadingHelper mReadingHelper;
|
||||||
Ver4ShortcutListPolicy *const mShortcutPolicy;
|
Ver4ShortcutListPolicy *const mShortcutPolicy;
|
||||||
};
|
};
|
||||||
|
|
|
@ -43,7 +43,6 @@ const char *const Ver4PatriciaTriePolicy::MAX_BIGRAM_COUNT_QUERY = "MAX_BIGRAM_C
|
||||||
const int Ver4PatriciaTriePolicy::MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS = 1024;
|
const int Ver4PatriciaTriePolicy::MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS = 1024;
|
||||||
const int Ver4PatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS =
|
const int Ver4PatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS =
|
||||||
Ver4DictConstants::MAX_DICTIONARY_SIZE - MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS;
|
Ver4DictConstants::MAX_DICTIONARY_SIZE - MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS;
|
||||||
const int Ver4PatriciaTriePolicy::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
|
|
||||||
|
|
||||||
void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode,
|
void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode,
|
||||||
DicNodeVector *const childDicNodes) const {
|
DicNodeVector *const childDicNodes) const {
|
||||||
|
@ -151,8 +150,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
|
||||||
}
|
}
|
||||||
const int probability = probabilityEntry.hasHistoricalInfo() ?
|
const int probability = probabilityEntry.hasHistoricalInfo() ?
|
||||||
ForgettingCurveUtils::decodeProbability(
|
ForgettingCurveUtils::decodeProbability(
|
||||||
probabilityEntry.getHistoricalInfo(), mHeaderPolicy)
|
probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
|
||||||
+ ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */) :
|
|
||||||
probabilityEntry.getProbability();
|
probabilityEntry.getProbability();
|
||||||
listener->onVisitEntry(probability, entry.getWordId());
|
listener->onVisitEntry(probability, entry.getWordId());
|
||||||
}
|
}
|
||||||
|
@ -371,25 +369,44 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
|
||||||
"dictionary.");
|
"dictionary.");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// TODO: Have count up method in language model dict content.
|
const bool updateAsAValidWord = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */) ?
|
||||||
const int probability = isValidWord ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY;
|
false : isValidWord;
|
||||||
|
int wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
|
||||||
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
|
// The word is not in the dictionary.
|
||||||
const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */,
|
const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */,
|
||||||
false /* isNotAWord */, false /* isBlacklisted */, probability, historicalInfo);
|
false /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
|
||||||
|
HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
|
||||||
if (!addUnigramEntry(wordCodePoints, &unigramProperty)) {
|
if (!addUnigramEntry(wordCodePoints, &unigramProperty)) {
|
||||||
AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext().");
|
AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext().");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const int probabilityForNgram = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)
|
wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
|
||||||
? NOT_A_PROBABILITY : probability;
|
}
|
||||||
const NgramProperty ngramProperty(wordCodePoints.toVector(), probabilityForNgram,
|
|
||||||
historicalInfo);
|
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
|
||||||
for (size_t i = 1; i <= ngramContext->getPrevWordCount(); ++i) {
|
const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
|
||||||
const NgramContext trimmedNgramContext(ngramContext->getTrimmedNgramContext(i));
|
false /* tryLowerCaseSearch */);
|
||||||
if (!addNgramEntry(&trimmedNgramContext, &ngramProperty)) {
|
if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID
|
||||||
AKLOGE("Cannot update ngram entry in updateEntriesForWordWithNgramContext().");
|
&& ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
|
||||||
|
const UnigramProperty beginningOfSentenceUnigramProperty(
|
||||||
|
true /* representsBeginningOfSentence */,
|
||||||
|
true /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
|
||||||
|
HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
|
||||||
|
if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
|
||||||
|
&beginningOfSentenceUnigramProperty)) {
|
||||||
|
AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext().");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
// Refresh word ids.
|
||||||
|
ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
|
||||||
}
|
}
|
||||||
|
int addedNewNgramEntryCount = 0;
|
||||||
|
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds,
|
||||||
|
wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &addedNewNgramEntryCount)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
mBigramCount += addedNewNgramEntryCount;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -47,8 +47,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
mShortcutPolicy(mBuffers->getMutableShortcutDictContent(),
|
mShortcutPolicy(mBuffers->getMutableShortcutDictContent(),
|
||||||
mBuffers->getTerminalPositionLookupTable()),
|
mBuffers->getTerminalPositionLookupTable()),
|
||||||
mNodeReader(mDictBuffer), mPtNodeArrayReader(mDictBuffer),
|
mNodeReader(mDictBuffer), mPtNodeArrayReader(mDictBuffer),
|
||||||
mNodeWriter(mDictBuffer, mBuffers.get(), mHeaderPolicy, &mNodeReader,
|
mNodeWriter(mDictBuffer, mBuffers.get(), &mNodeReader, &mPtNodeArrayReader,
|
||||||
&mPtNodeArrayReader, &mShortcutPolicy),
|
&mShortcutPolicy),
|
||||||
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
|
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
|
||||||
mWritingHelper(mBuffers.get()),
|
mWritingHelper(mBuffers.get()),
|
||||||
mUnigramCount(mHeaderPolicy->getUnigramCount()),
|
mUnigramCount(mHeaderPolicy->getUnigramCount()),
|
||||||
|
@ -131,8 +131,6 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
// prevent the dictionary from overflowing.
|
// prevent the dictionary from overflowing.
|
||||||
static const int MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS;
|
static const int MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS;
|
||||||
static const int MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS;
|
static const int MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS;
|
||||||
// TODO: Remove
|
|
||||||
static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;
|
|
||||||
|
|
||||||
const Ver4DictBuffers::Ver4DictBuffersPtr mBuffers;
|
const Ver4DictBuffers::Ver4DictBuffersPtr mBuffers;
|
||||||
const HeaderPolicy *const mHeaderPolicy;
|
const HeaderPolicy *const mHeaderPolicy;
|
||||||
|
@ -144,6 +142,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
DynamicPtUpdatingHelper mUpdatingHelper;
|
DynamicPtUpdatingHelper mUpdatingHelper;
|
||||||
Ver4PatriciaTrieWritingHelper mWritingHelper;
|
Ver4PatriciaTrieWritingHelper mWritingHelper;
|
||||||
int mUnigramCount;
|
int mUnigramCount;
|
||||||
|
// TODO: Support counting ngram entries.
|
||||||
int mBigramCount;
|
int mBigramCount;
|
||||||
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
|
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
|
||||||
mutable bool mIsCorrupted;
|
mutable bool mIsCorrupted;
|
||||||
|
|
|
@ -78,11 +78,11 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
|
||||||
Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
|
Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
|
||||||
mBuffers->getTerminalPositionLookupTable());
|
mBuffers->getTerminalPositionLookupTable());
|
||||||
Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
|
Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
|
||||||
mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
|
mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
|
||||||
|
|
||||||
int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||||
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
|
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC(
|
||||||
entryCountTable)) {
|
headerPolicy, entryCountTable)) {
|
||||||
AKLOGE("Failed to update probabilities in language model dict content.");
|
AKLOGE("Failed to update probabilities in language model dict content.");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -118,7 +118,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
|
||||||
PtNodeWriter::DictPositionRelocationMap dictPositionRelocationMap;
|
PtNodeWriter::DictPositionRelocationMap dictPositionRelocationMap;
|
||||||
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
|
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
|
||||||
Ver4PatriciaTrieNodeWriter ptNodeWriterForNewBuffers(buffersToWrite->getWritableTrieBuffer(),
|
Ver4PatriciaTrieNodeWriter ptNodeWriterForNewBuffers(buffersToWrite->getWritableTrieBuffer(),
|
||||||
buffersToWrite, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
|
buffersToWrite, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
|
||||||
DynamicPtGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer
|
DynamicPtGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer
|
||||||
traversePolicyToPlaceAndWriteValidPtNodesToBuffer(&ptNodeWriterForNewBuffers,
|
traversePolicyToPlaceAndWriteValidPtNodesToBuffer(&ptNodeWriterForNewBuffers,
|
||||||
buffersToWrite->getWritableTrieBuffer(), &dictPositionRelocationMap);
|
buffersToWrite->getWritableTrieBuffer(), &dictPositionRelocationMap);
|
||||||
|
@ -133,7 +133,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
|
||||||
Ver4ShortcutListPolicy newShortcutPolicy(buffersToWrite->getMutableShortcutDictContent(),
|
Ver4ShortcutListPolicy newShortcutPolicy(buffersToWrite->getMutableShortcutDictContent(),
|
||||||
buffersToWrite->getTerminalPositionLookupTable());
|
buffersToWrite->getTerminalPositionLookupTable());
|
||||||
Ver4PatriciaTrieNodeWriter newPtNodeWriter(buffersToWrite->getWritableTrieBuffer(),
|
Ver4PatriciaTrieNodeWriter newPtNodeWriter(buffersToWrite->getWritableTrieBuffer(),
|
||||||
buffersToWrite, headerPolicy, &newPtNodeReader, &newPtNodeArrayreader,
|
buffersToWrite, &newPtNodeReader, &newPtNodeArrayreader,
|
||||||
&newShortcutPolicy);
|
&newShortcutPolicy);
|
||||||
// Re-assign terminal IDs for valid terminal PtNodes.
|
// Re-assign terminal IDs for valid terminal PtNodes.
|
||||||
TerminalPositionLookupTable::TerminalIdMap terminalIdMap;
|
TerminalPositionLookupTable::TerminalIdMap terminalIdMap;
|
||||||
|
|
Loading…
Reference in a new issue