Merge "Update probabilities in language model dict content for GC."

This commit is contained in:
Keisuke Kuroyanagi 2014-08-25 03:08:57 +00:00 committed by Android (Google) Code Review
commit 094a8a68e3
5 changed files with 72 additions and 21 deletions

View file

@ -16,6 +16,8 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
namespace latinime {
bool LanguageModelDictContent::save(FILE *const file) const {
@ -118,4 +120,46 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
return bitmapEntryIndex;
}
bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
return false;
}
const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
probabilityEntry.getHistoricalInfo(), headerPolicy);
if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
// Update the entry.
const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
bitmapEntryIndex)) {
return false;
}
} else {
// Remove the entry.
if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
return false;
}
continue;
}
}
if (!probabilityEntry.representsBeginningOfSentence()) {
outEntryCounts[level] += 1;
}
if (!entry.hasNextLevelMap()) {
continue;
}
if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
headerPolicy, outEntryCounts)) {
return false;
}
}
return true;
}
} // namespace latinime

View file

@ -29,6 +29,8 @@
namespace latinime {
class HeaderPolicy;
/**
* Class representing language model.
*
@ -73,6 +75,12 @@ class LanguageModelDictContent {
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) {
return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
headerPolicy, outEntryCounts);
}
private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
@ -84,6 +92,8 @@ class LanguageModelDictContent {
int *const outNgramCount);
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */

View file

@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
const ProbabilityEntry originalProbabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId());
if (originalProbabilityEntry.hasHistoricalInfo()) {
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
&historicalInfo);
if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
AKLOGE("Cannot write updated probability entry. terminalId: %d",
toBeUpdatedPtNodeParams->getTerminalId());
return false;
}
const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
if (!isValid) {
if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
return false;
}
}
*outNeedsToKeepPtNode = isValid;
} else {
// No need to update probability.
if (originalProbabilityEntry.isValid()) {
*outNeedsToKeepPtNode = true;
return true;
}
if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
return false;
}
*outNeedsToKeepPtNode = false;
return true;
}
@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
// TODO: Move probability handling code to LanguageModelDictContent.
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry,
const ProbabilityEntry *const probabilityEntry) const {

View file

@ -85,6 +85,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
&shortcutPolicy);
int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
entryCountTable)) {
AKLOGE("Failed to update probabilities in language model dict content.");
return false;
}
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
DynamicPtGcEventListeners

View file

@ -84,6 +84,10 @@ class TrieMap {
return mValue;
}
AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const {
return mNextLevelBitmapEntryIndex;
}
private:
const TrieMap *const mTrieMap;
const int mKey;