Merge "Use EntryCounters during GC."

This commit is contained in:
Keisuke Kuroyanagi 2014-10-21 07:55:03 +00:00 committed by Android (Google) Code Review
commit fa1e65cb3a
5 changed files with 54 additions and 65 deletions

View file

@ -23,8 +23,6 @@
namespace latinime { namespace latinime {
const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1; const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
bool LanguageModelDictContent::save(FILE *const file) const { bool LanguageModelDictContent::save(FILE *const file) const {
@ -33,10 +31,9 @@ bool LanguageModelDictContent::save(FILE *const file) const {
bool LanguageModelDictContent::runGC( bool LanguageModelDictContent::runGC(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const LanguageModelDictContent *const originalContent, const LanguageModelDictContent *const originalContent) {
int *const outNgramCount) {
return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(), return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
0 /* nextLevelBitmapEntryIndex */, outNgramCount); 0 /* nextLevelBitmapEntryIndex */);
} }
const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds, const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
@ -143,18 +140,23 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo); return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
} }
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts, bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy, const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) { MutableEntryCounters *const outEntryCounters) {
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
if (entryCounts[i] <= maxEntryCounts[i]) { const int totalWordCount = prevWordCount + 1;
outEntryCounts[i] = entryCounts[i]; if (currentEntryCounts.getNgramCount(totalWordCount)
<= maxEntryCounts.getNgramCount(totalWordCount)) {
outEntryCounters->setNgramCount(totalWordCount,
currentEntryCounts.getNgramCount(totalWordCount));
continue; continue;
} }
if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i, int entryCount = 0;
&outEntryCounts[i])) { if (!turncateEntriesInSpecifiedLevel(headerPolicy,
maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
return false; return false;
} }
outEntryCounters->setNgramCount(totalWordCount, entryCount);
} }
return true; return true;
} }
@ -208,8 +210,7 @@ const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
bool LanguageModelDictContent::runGCInner( bool LanguageModelDictContent::runGCInner(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange, const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) {
const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
for (auto &entry : trieMapRange) { for (auto &entry : trieMapRange) {
const auto it = terminalIdMap->find(entry.key()); const auto it = terminalIdMap->find(entry.key());
if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) { if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
@ -219,13 +220,9 @@ bool LanguageModelDictContent::runGCInner(
if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) { if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
return false; return false;
} }
if (outNgramCount) {
*outNgramCount += 1;
}
if (entry.hasNextLevelMap()) { if (entry.hasNextLevelMap()) {
if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(), if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex), mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) {
outNgramCount)) {
return false; return false;
} }
} }
@ -268,7 +265,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
const int prevWordCount, const HeaderPolicy *const headerPolicy, const int prevWordCount, const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) { MutableEntryCounters *const outEntryCounters) {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@ -305,13 +302,13 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
} }
} }
if (!probabilityEntry.representsBeginningOfSentence()) { if (!probabilityEntry.representsBeginningOfSentence()) {
outEntryCounts[prevWordCount] += 1; outEntryCounters->incrementNgramCount(prevWordCount + 1);
} }
if (!entry.hasNextLevelMap()) { if (!entry.hasNextLevelMap()) {
continue; continue;
} }
if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
prevWordCount + 1, headerPolicy, outEntryCounts)) { prevWordCount + 1, headerPolicy, outEntryCounters)) {
return false; return false;
} }
} }

View file

@ -41,9 +41,6 @@ class HeaderPolicy;
*/ */
class LanguageModelDictContent { class LanguageModelDictContent {
public: public:
static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
// Pair of word id and probability entry used for iteration. // Pair of word id and probability entry used for iteration.
class WordIdAndProbabilityEntry { class WordIdAndProbabilityEntry {
public: public:
@ -127,8 +124,7 @@ class LanguageModelDictContent {
bool save(FILE *const file) const; bool save(FILE *const file) const;
bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const LanguageModelDictContent *const originalContent, const LanguageModelDictContent *const originalContent);
int *const outNgramCount);
const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
const HeaderPolicy *const headerPolicy) const; const HeaderPolicy *const headerPolicy) const;
@ -156,17 +152,14 @@ class LanguageModelDictContent {
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const; EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) { MutableEntryCounters *const outEntryCounters) {
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
outEntryCounts[i] = 0;
}
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
0 /* prevWordCount */, headerPolicy, outEntryCounts); 0 /* prevWordCount */, headerPolicy, outEntryCounters);
} }
// entryCounts should be created by updateAllProbabilityEntries. // entryCounts should be created by updateAllProbabilityEntries.
bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts, bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts,
const HeaderPolicy *const headerPolicy, int *const outEntryCounts); const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId, bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
const bool isValid, const HistoricalInfo historicalInfo, const bool isValid, const HistoricalInfo historicalInfo,
@ -206,12 +199,11 @@ class LanguageModelDictContent {
const bool mHasHistoricalInfo; const bool mHasHistoricalInfo;
bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex, const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex);
int *const outNgramCount);
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
const HeaderPolicy *const headerPolicy, int *const outEntryCounts); const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
const int maxEntryCount, const int targetLevel, int *const outEntryCount); const int maxEntryCount, const int targetLevel, int *const outEntryCount);
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,

View file

@ -57,16 +57,14 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers( Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers(
Ver4DictBuffers::createVer4DictBuffers(headerPolicy, Ver4DictBuffers::createVer4DictBuffers(headerPolicy,
Ver4DictConstants::MAX_DICTIONARY_SIZE)); Ver4DictConstants::MAX_DICTIONARY_SIZE));
int unigramCount = 0; MutableEntryCounters entryCounters;
int bigramCount = 0; if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &entryCounters)) {
if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &unigramCount, &bigramCount)) {
return false; return false;
} }
BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */), entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
0 /* extendedRegionSize */, &headerBuffer)) {
return false; return false;
} }
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@ -74,7 +72,7 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite, const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite,
int *const outUnigramCount, int *const outBigramCount) { MutableEntryCounters *const outEntryCounters) {
Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer()); Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer());
Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer()); Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer());
Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(), Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
@ -82,24 +80,17 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(), Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC( if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC(
headerPolicy, entryCountTable)) { headerPolicy, outEntryCounters)) {
AKLOGE("Failed to update probabilities in language model dict content."); AKLOGE("Failed to update probabilities in language model dict content.");
return false; return false;
} }
if (headerPolicy->isDecayingDict()) { if (headerPolicy->isDecayingDict()) {
int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(),
maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] = headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
headerPolicy->getMaxUnigramCount(); if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] = outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
headerPolicy->getMaxBigramCount(); outEntryCounters)) {
for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
// TODO: Have max n-gram count.
maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
}
if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
maxEntryCountTable, headerPolicy, entryCountTable)) {
AKLOGE("Failed to truncate entries in language model dict content."); AKLOGE("Failed to truncate entries in language model dict content.");
return false; return false;
} }
@ -143,9 +134,9 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&terminalIdMap)) { &terminalIdMap)) {
return false; return false;
} }
// Run GC for probability dict content. // Run GC for language model dict content.
if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap, if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap,
mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) { mBuffers->getLanguageModelDictContent())) {
return false; return false;
} }
// Run GC for shortcut dict content. // Run GC for shortcut dict content.
@ -168,10 +159,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) { &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) {
return false; return false;
} }
*outUnigramCount =
entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
*outBigramCount =
entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
return true; return true;
} }

View file

@ -67,8 +67,7 @@ class Ver4PatriciaTrieWritingHelper {
}; };
bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy,
Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount, Ver4DictBuffers *const buffersToWrite, MutableEntryCounters *const outEntryCounters);
int *const outBigramCount);
Ver4DictBuffers *const mBuffers; Ver4DictBuffers *const mBuffers;
}; };

View file

@ -46,6 +46,13 @@ class EntryCounts final {
return mEntryCounts[2]; return mEntryCounts[2];
} }
int getNgramCount(const size_t n) const {
if (n < 1 || n > mEntryCounts.size()) {
return 0;
}
return mEntryCounts[n - 1];
}
private: private:
DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts); DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);
@ -110,6 +117,13 @@ class MutableEntryCounters final {
--mEntryCounters[n - 1]; --mEntryCounters[n - 1];
} }
void setNgramCount(const size_t n, const int count) {
if (n < 1 || n > mEntryCounters.size()) {
return;
}
mEntryCounters[n - 1] = count;
}
private: private:
DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters); DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);