Truncate entries in language model dict content.

Bug: 14425059

Change-Id: I023c1d5109a2c43fcea3bb11a0fd7198c82891ba
This commit is contained in:
Keisuke Kuroyanagi 2014-08-21 12:48:24 +09:00
parent 9aa6699107
commit 063f86d40f
3 changed files with 152 additions and 0 deletions

View file

@ -16,6 +16,9 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
#include <algorithm>
#include <cstring>
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
namespace latinime { namespace latinime {
@ -68,6 +71,19 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
return mTrieMap.remove(wordId, bitmapEntryIndex); return mTrieMap.remove(wordId, bitmapEntryIndex);
} }
// Trims every n-gram level whose entry count is above its limit.
//
// entryCounts and maxEntryCounts are both indexed by level and are read for indices 0 through
// MAX_PREV_WORD_COUNT_FOR_N_GRAM inclusive. A level that is already within its limit is left
// untouched. Returns false as soon as trimming one level fails, in which case the levels after
// it are not processed; returns true otherwise.
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
        const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
    for (int level = 0; level <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++level) {
        const int limit = maxEntryCounts[level];
        const bool isOverLimit = entryCounts[level] > limit;
        // Short-circuit keeps the helper from running for levels that need no trimming.
        if (isOverLimit && !turncateEntriesInSpecifiedLevel(headerPolicy, limit, level)) {
            return false;
        }
    }
    return true;
}
bool LanguageModelDictContent::runGCInner( bool LanguageModelDictContent::runGCInner(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange, const TrieMap::TrieMapRange trieMapRange,
@ -162,4 +178,87 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
return true; return true;
} }
// Removes entries from one n-gram level until at most maxEntryCount of them remain.
//
// The entries of targetLevel are collected with getEntryInfo, the ones that sort first under
// EntryInfoToTurncate::Comparator are selected, and each selected entry is removed with
// removeNgramProbabilityEntry. A negative maxEntryCount is treated as zero, i.e. every entry
// of the level is removed.
//
// Returns false if collecting the entries or removing any one of them fails, true otherwise
// (including when the level is already within the limit).
bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
        const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) {
    std::vector<int> prevWordIds;
    std::vector<EntryInfoToTurncate> entryInfoVector;
    if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
            &prevWordIds, &entryInfoVector)) {
        return false;
    }
    const int entryCount = static_cast<int>(entryInfoVector.size());
    if (entryCount <= maxEntryCount) {
        return true;
    }
    // Clamp the limit at zero: with a negative limit the subtraction would yield more removals
    // than there are entries, which would push the partial_sort middle iterator past end() and
    // make the loop below index out of bounds.
    const int entryCountToRemove = entryCount - std::max(maxEntryCount, 0);
    // Only the first entryCountToRemove elements need to be in order; the rest are kept.
    std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
            entryInfoVector.end(),
            EntryInfoToTurncate::Comparator());
    for (int i = 0; i < entryCountToRemove; ++i) {
        const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
        if (!removeNgramProbabilityEntry(
                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
            return false;
        }
    }
    return true;
}
// Recursively gathers one EntryInfoToTurncate for every entry found at targetLevel below the
// trie map level identified by bitmapEntryIndex.
//
// prevWordIds doubles as the traversal stack: its size is the level currently being visited and
// its contents are the keys followed to get there. While the current level is above the target,
// only entries that have a next-level map are followed. At the target level each entry is
// decoded and appended to outEntryInfo; the recorded probability comes from
// ForgettingCurveUtils::decodeProbability when the content has historical info and from the
// stored probability otherwise.
//
// Returns false if a recursive call fails, true otherwise.
bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
        const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
        std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
    const int depth = prevWordIds->size();
    // The depth does not change while iterating this level, so decide once.
    const bool atTargetLevel = !(depth < targetLevel);
    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
        if (atTargetLevel) {
            const ProbabilityEntry probabilityEntry =
                    ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
            const int probability = (mHasHistoricalInfo) ?
                    ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
                            headerPolicy) : probabilityEntry.getProbability();
            outEntryInfo->emplace_back(probability,
                    probabilityEntry.getHistoricalInfo()->getTimeStamp(),
                    entry.key(), targetLevel, prevWordIds->data());
            continue;
        }
        if (entry.hasNextLevelMap()) {
            // Descend with this entry's key appended to the path, then restore the path.
            prevWordIds->push_back(entry.key());
            const bool descended = getEntryInfo(headerPolicy, targetLevel,
                    entry.getNextLevelBitmapEntryIndex(), prevWordIds, outEntryInfo);
            if (!descended) {
                return false;
            }
            prevWordIds->pop_back();
        }
    }
    return true;
}
// Strict ordering over entries; entries that order first are the first to be truncated.
//
// Fields are compared in this sequence, stopping at the first difference:
//   1. probability, smaller first
//   2. timestamp, larger first
//   3. key, smaller first
//   4. entry level, larger first
//   5. previous word ids, element by element, smaller first
// NOTE(review): for equal probabilities the larger timestamp orders first, so it is removed
// earlier -- confirm that this direction is intended.
bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
        const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
    if (left.mProbability < right.mProbability) {
        return true;
    }
    if (left.mProbability > right.mProbability) {
        return false;
    }
    if (left.mTimestamp > right.mTimestamp) {
        return true;
    }
    if (left.mTimestamp < right.mTimestamp) {
        return false;
    }
    if (left.mKey < right.mKey) {
        return true;
    }
    if (left.mKey > right.mKey) {
        return false;
    }
    if (left.mEntryLevel > right.mEntryLevel) {
        return true;
    }
    if (left.mEntryLevel < right.mEntryLevel) {
        return false;
    }
    // Both sides have the same entry level here, so one bound serves both arrays.
    const int idCount = left.mEntryLevel;
    int index = 0;
    while (index < idCount) {
        const int leftId = left.mPrevWordIds[index];
        const int rightId = right.mPrevWordIds[index];
        if (leftId != rightId) {
            return leftId < rightId;
        }
        ++index;
    }
    // left and right represent the same entry.
    return false;
}
// Captures what is needed to rank one entry and to remove it later.
//
// probability, timestamp, key and entryLevel are stored as given. The first entryLevel ids
// pointed to by prevWordIds are copied into mPrevWordIds, so the object does not keep a
// reference to the caller's buffer. prevWordIds must not point into this object's own
// mPrevWordIds. entryLevel is expected to fit in mPrevWordIds; this is not checked here.
LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
        const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
        : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
    // std::copy_n does nothing when the count is zero or negative, so the source pointer is
    // never touched for level-0 entries, where the caller passes data() of an empty vector and
    // that pointer may be null. memmove gives no such guarantee for a null source even with a
    // zero length, and a negative level would have turned into a huge unsigned byte count.
    std::copy_n(prevWordIds, mEntryLevel, mPrevWordIds);
}
} // namespace latinime } // namespace latinime

View file

@ -18,6 +18,7 @@
#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H #define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
#include <cstdio> #include <cstdio>
#include <vector>
#include "defines.h" #include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
@ -77,13 +78,43 @@ class LanguageModelDictContent {
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy, bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) { int *const outEntryCounts) {
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
outEntryCounts[i] = 0;
}
return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */, return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
headerPolicy, outEntryCounts); headerPolicy, outEntryCounts);
} }
// Removes entries from every n-gram level whose count in entryCounts is greater than the value
// at the same index in maxEntryCounts. Both arrays are indexed by level and are read for
// indices 0 through MAX_PREV_WORD_COUNT_FOR_N_GRAM inclusive. Returns false if truncating any
// level fails, true otherwise.
// entryCounts should be created by updateAllProbabilityEntries.
bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
const HeaderPolicy *const headerPolicy);
private: private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
// Describes one language model entry that is a candidate for truncation: the values used to
// rank it, plus the key and previous word ids used to remove it afterwards.
class EntryInfoToTurncate {
public:
// Ordering used to pick the entries to truncate; entries that order first are removed first.
// Compares probability (smaller first), then timestamp (larger first), then key (smaller
// first), then entry level (larger first), then the previous word ids (smaller first).
class Comparator {
public:
bool operator()(const EntryInfoToTurncate &left,
const EntryInfoToTurncate &right) const;
private:
DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
};
// Stores the given values and copies the first entryLevel ids from prevWordIds into
// mPrevWordIds.
EntryInfoToTurncate(const int probability, const int timestamp, const int key,
const int entryLevel, const int *const prevWordIds);
// Probability used for ranking.
int mProbability;
// Timestamp taken from the entry's historical info.
int mTimestamp;
// Key of the entry within its trie map level.
int mKey;
// Level of the entry; also the number of valid elements in mPrevWordIds.
int mEntryLevel;
// Keys followed from the root level to reach the entry. Only the first mEntryLevel elements
// are written by the constructor.
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
private:
DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
};
TrieMap mTrieMap; TrieMap mTrieMap;
const bool mHasHistoricalInfo; const bool mHasHistoricalInfo;
@ -94,6 +125,11 @@ class LanguageModelDictContent {
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
const HeaderPolicy *const headerPolicy, int *const outEntryCounts); const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
// Removes entries at targetLevel until no more than maxEntryCount remain, choosing the entries
// that order first under EntryInfoToTurncate::Comparator. Returns false if collecting the
// entries or removing one of them fails.
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
const int maxEntryCount, const int targetLevel);
// Recursively walks the trie map from bitmapEntryIndex and appends one EntryInfoToTurncate to
// outEntryInfo for every entry at targetLevel. prevWordIds holds the keys followed so far; its
// size is taken as the current level. Returns false if a recursive call fails.
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
}; };
} // namespace latinime } // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */

View file

@ -91,6 +91,21 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
AKLOGE("Failed to update probabilities in language model dict content."); AKLOGE("Failed to update probabilities in language model dict content.");
return false; return false;
} }
if (headerPolicy->isDecayingDict()) {
int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount();
maxEntryCountTable[1] = headerPolicy->getMaxBigramCount();
for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
// TODO: Have max n-gram count.
maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
}
if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
maxEntryCountTable, headerPolicy)) {
AKLOGE("Failed to truncate entries in language model dict content.");
return false;
}
}
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader); DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
DynamicPtGcEventListeners DynamicPtGcEventListeners
@ -193,6 +208,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return true; return true;
} }
// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
const Ver4PatriciaTrieNodeReader *const ptNodeReader, const Ver4PatriciaTrieNodeReader *const ptNodeReader,
Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) { Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
@ -233,6 +249,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
return true; return true;
} }
// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) { bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) {
const TerminalPositionLookupTable *const terminalPosLookupTable = const TerminalPositionLookupTable *const terminalPosLookupTable =
mBuffers->getTerminalPositionLookupTable(); mBuffers->getTerminalPositionLookupTable();