Merge "Truncate entries in language model dict content."
This commit is contained in:
commit
394677674b
3 changed files with 152 additions and 0 deletions
|
@ -16,6 +16,9 @@
|
|||
|
||||
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
|
||||
|
||||
namespace latinime {
|
||||
|
@ -68,6 +71,19 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
|
|||
return mTrieMap.remove(wordId, bitmapEntryIndex);
|
||||
}
|
||||
|
||||
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
|
||||
const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
|
||||
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
||||
if (entryCounts[i] <= maxEntryCounts[i]) {
|
||||
continue;
|
||||
}
|
||||
if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LanguageModelDictContent::runGCInner(
|
||||
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
|
||||
const TrieMap::TrieMapRange trieMapRange,
|
||||
|
@ -162,4 +178,87 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
|
|||
return true;
|
||||
}
|
||||
|
||||
bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
|
||||
const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) {
|
||||
std::vector<int> prevWordIds;
|
||||
std::vector<EntryInfoToTurncate> entryInfoVector;
|
||||
if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
|
||||
&prevWordIds, &entryInfoVector)) {
|
||||
return false;
|
||||
}
|
||||
if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
|
||||
return true;
|
||||
}
|
||||
const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
|
||||
std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
|
||||
entryInfoVector.end(),
|
||||
EntryInfoToTurncate::Comparator());
|
||||
for (int i = 0; i < entryCountToRemove; ++i) {
|
||||
const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
|
||||
if (!removeNgramProbabilityEntry(
|
||||
WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
|
||||
const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
|
||||
std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
|
||||
const int currentLevel = prevWordIds->size();
|
||||
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
|
||||
if (currentLevel < targetLevel) {
|
||||
if (!entry.hasNextLevelMap()) {
|
||||
continue;
|
||||
}
|
||||
prevWordIds->push_back(entry.key());
|
||||
if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(),
|
||||
prevWordIds, outEntryInfo)) {
|
||||
return false;
|
||||
}
|
||||
prevWordIds->pop_back();
|
||||
continue;
|
||||
}
|
||||
const ProbabilityEntry probabilityEntry =
|
||||
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
|
||||
const int probability = (mHasHistoricalInfo) ?
|
||||
ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
|
||||
headerPolicy) : probabilityEntry.getProbability();
|
||||
outEntryInfo->emplace_back(probability,
|
||||
probabilityEntry.getHistoricalInfo()->getTimeStamp(),
|
||||
entry.key(), targetLevel, prevWordIds->data());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
|
||||
const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
|
||||
if (left.mProbability != right.mProbability) {
|
||||
return left.mProbability < right.mProbability;
|
||||
}
|
||||
if (left.mTimestamp != right.mTimestamp) {
|
||||
return left.mTimestamp > right.mTimestamp;
|
||||
}
|
||||
if (left.mKey != right.mKey) {
|
||||
return left.mKey < right.mKey;
|
||||
}
|
||||
if (left.mEntryLevel != right.mEntryLevel) {
|
||||
return left.mEntryLevel > right.mEntryLevel;
|
||||
}
|
||||
for (int i = 0; i < left.mEntryLevel; ++i) {
|
||||
if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
|
||||
return left.mPrevWordIds[i] < right.mPrevWordIds[i];
|
||||
}
|
||||
}
|
||||
// left and rigth represent the same entry.
|
||||
return false;
|
||||
}
|
||||
|
||||
LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
|
||||
const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
|
||||
: mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
|
||||
memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0]));
|
||||
}
|
||||
|
||||
} // namespace latinime
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
|
||||
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
|
||||
#include "defines.h"
|
||||
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
|
||||
|
@ -77,13 +78,43 @@ class LanguageModelDictContent {
|
|||
|
||||
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
|
||||
int *const outEntryCounts) {
|
||||
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
||||
outEntryCounts[i] = 0;
|
||||
}
|
||||
return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
|
||||
headerPolicy, outEntryCounts);
|
||||
}
|
||||
|
||||
// entryCounts should be created by updateAllProbabilityEntries.
|
||||
bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
|
||||
const HeaderPolicy *const headerPolicy);
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
|
||||
|
||||
class EntryInfoToTurncate {
|
||||
public:
|
||||
class Comparator {
|
||||
public:
|
||||
bool operator()(const EntryInfoToTurncate &left,
|
||||
const EntryInfoToTurncate &right) const;
|
||||
private:
|
||||
DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
|
||||
};
|
||||
|
||||
EntryInfoToTurncate(const int probability, const int timestamp, const int key,
|
||||
const int entryLevel, const int *const prevWordIds);
|
||||
|
||||
int mProbability;
|
||||
int mTimestamp;
|
||||
int mKey;
|
||||
int mEntryLevel;
|
||||
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||
|
||||
private:
|
||||
DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
|
||||
};
|
||||
|
||||
TrieMap mTrieMap;
|
||||
const bool mHasHistoricalInfo;
|
||||
|
||||
|
@ -94,6 +125,11 @@ class LanguageModelDictContent {
|
|||
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
|
||||
bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
|
||||
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
|
||||
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
|
||||
const int maxEntryCount, const int targetLevel);
|
||||
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
|
||||
const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
|
||||
std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
|
||||
};
|
||||
} // namespace latinime
|
||||
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
|
||||
|
|
|
@ -91,6 +91,21 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
|
|||
AKLOGE("Failed to update probabilities in language model dict content.");
|
||||
return false;
|
||||
}
|
||||
if (headerPolicy->isDecayingDict()) {
|
||||
int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||
maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount();
|
||||
maxEntryCountTable[1] = headerPolicy->getMaxBigramCount();
|
||||
for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
|
||||
// TODO: Have max n-gram count.
|
||||
maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
|
||||
}
|
||||
if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
|
||||
maxEntryCountTable, headerPolicy)) {
|
||||
AKLOGE("Failed to truncate entries in language model dict content.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
|
||||
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
|
||||
DynamicPtGcEventListeners
|
||||
|
@ -193,6 +208,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
|
|||
return true;
|
||||
}
|
||||
|
||||
// TODO: Remove.
|
||||
bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
|
||||
const Ver4PatriciaTrieNodeReader *const ptNodeReader,
|
||||
Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
|
||||
|
@ -233,6 +249,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
|
|||
return true;
|
||||
}
|
||||
|
||||
// TODO: Remove.
|
||||
bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) {
|
||||
const TerminalPositionLookupTable *const terminalPosLookupTable =
|
||||
mBuffers->getTerminalPositionLookupTable();
|
||||
|
|
Loading…
Reference in a new issue