Get entry count after truncation using LanguageModelDictContent.

Bug: 14425059
Change-Id: I41b237c1c22c21740946d52e3be9d6f963c9cd54
This commit is contained in:
Keisuke Kuroyanagi 2014-08-27 20:04:39 +09:00
parent c7f1de826c
commit 758d093644
3 changed files with 26 additions and 9 deletions

View file

@ -23,6 +23,9 @@
namespace latinime { namespace latinime {
const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
bool LanguageModelDictContent::save(FILE *const file) const { bool LanguageModelDictContent::save(FILE *const file) const {
return mTrieMap.save(file); return mTrieMap.save(file);
} }
@ -78,12 +81,15 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
} }
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts, bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) { const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) {
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
if (entryCounts[i] <= maxEntryCounts[i]) { if (entryCounts[i] <= maxEntryCounts[i]) {
outEntryCounts[i] = entryCounts[i];
continue; continue;
} }
if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) { if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i,
&outEntryCounts[i])) {
return false; return false;
} }
} }
@ -185,7 +191,8 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
} }
bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) { const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel,
int *const outEntryCount) {
std::vector<int> prevWordIds; std::vector<int> prevWordIds;
std::vector<EntryInfoToTurncate> entryInfoVector; std::vector<EntryInfoToTurncate> entryInfoVector;
if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(), if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
@ -193,8 +200,10 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
return false; return false;
} }
if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) { if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
*outEntryCount = static_cast<int>(entryInfoVector.size());
return true; return true;
} }
*outEntryCount = maxEntryCount;
const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount; const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove, std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
entryInfoVector.end(), entryInfoVector.end(),

View file

@ -39,6 +39,9 @@ class HeaderPolicy;
*/ */
class LanguageModelDictContent { class LanguageModelDictContent {
public: public:
static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
// Pair of word id and probability entry used for iteration. // Pair of word id and probability entry used for iteration.
class WordIdAndProbabilityEntry { class WordIdAndProbabilityEntry {
public: public:
@ -158,7 +161,7 @@ class LanguageModelDictContent {
// entryCounts should be created by updateAllProbabilityEntries. // entryCounts should be created by updateAllProbabilityEntries.
bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts, bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
const HeaderPolicy *const headerPolicy); const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
private: private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
@ -197,7 +200,7 @@ class LanguageModelDictContent {
bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
const HeaderPolicy *const headerPolicy, int *const outEntryCounts); const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
const int maxEntryCount, const int targetLevel); const int maxEntryCount, const int targetLevel, int *const outEntryCount);
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
const int bitmapEntryIndex, std::vector<int> *const prevWordIds, const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
std::vector<EntryInfoToTurncate> *const outEntryInfo) const; std::vector<EntryInfoToTurncate> *const outEntryInfo) const;

View file

@ -93,14 +93,16 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
} }
if (headerPolicy->isDecayingDict()) { if (headerPolicy->isDecayingDict()) {
int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount(); maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
maxEntryCountTable[1] = headerPolicy->getMaxBigramCount(); headerPolicy->getMaxUnigramCount();
maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
headerPolicy->getMaxBigramCount();
for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) { for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
// TODO: Have max n-gram count. // TODO: Have max n-gram count.
maxEntryCountTable[i] = headerPolicy->getMaxBigramCount(); maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
} }
if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable, if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
maxEntryCountTable, headerPolicy)) { maxEntryCountTable, headerPolicy, entryCountTable)) {
AKLOGE("Failed to truncate entries in language model dict content."); AKLOGE("Failed to truncate entries in language model dict content.");
return false; return false;
} }
@ -204,7 +206,10 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) { &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) {
return false; return false;
} }
*outUnigramCount = traversePolicyToUpdateAllPositionFields.getUnigramCount(); *outUnigramCount =
entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
*outBigramCount =
entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
return true; return true;
} }