Update probabilities in language model dict content for GC.

Bug: 14425059
Change-Id: I354408afd8e5c1955ff0acea3d0243d628fe3843
main
Keisuke Kuroyanagi 2014-08-22 20:07:54 +09:00
parent cdc260b78e
commit 9aa6699107
5 changed files with 72 additions and 21 deletions

View File

@ -16,6 +16,8 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
namespace latinime { namespace latinime {
bool LanguageModelDictContent::save(FILE *const file) const { bool LanguageModelDictContent::save(FILE *const file) const {
@ -118,4 +120,46 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
return bitmapEntryIndex; return bitmapEntryIndex;
} }
bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
return false;
}
const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
probabilityEntry.getHistoricalInfo(), headerPolicy);
if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
// Update the entry.
const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
bitmapEntryIndex)) {
return false;
}
} else {
// Remove the entry.
if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
return false;
}
continue;
}
}
if (!probabilityEntry.representsBeginningOfSentence()) {
outEntryCounts[level] += 1;
}
if (!entry.hasNextLevelMap()) {
continue;
}
if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
headerPolicy, outEntryCounts)) {
return false;
}
}
return true;
}
} // namespace latinime } // namespace latinime

View File

@ -29,6 +29,8 @@
namespace latinime { namespace latinime {
class HeaderPolicy;
/** /**
* Class representing language model. * Class representing language model.
* *
@ -73,6 +75,12 @@ class LanguageModelDictContent {
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId); bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) {
return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
headerPolicy, outEntryCounts);
}
private: private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
@ -84,6 +92,8 @@ class LanguageModelDictContent {
int *const outNgramCount); int *const outNgramCount);
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
}; };
} // namespace latinime } // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */

View File

@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
const ProbabilityEntry originalProbabilityEntry = const ProbabilityEntry originalProbabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry( mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId()); toBeUpdatedPtNodeParams->getTerminalId());
if (originalProbabilityEntry.hasHistoricalInfo()) { if (originalProbabilityEntry.isValid()) {
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( *outNeedsToKeepPtNode = true;
originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy); return true;
const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
&historicalInfo);
if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
AKLOGE("Cannot write updated probability entry. terminalId: %d",
toBeUpdatedPtNodeParams->getTerminalId());
return false;
} }
const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
if (!isValid) {
if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
return false; return false;
} }
} *outNeedsToKeepPtNode = false;
*outNeedsToKeepPtNode = isValid;
} else {
// No need to update probability.
*outNeedsToKeepPtNode = true;
}
return true; return true;
} }
@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
} }
// TODO: Move probability handling code to LanguageModelDictContent.
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry, const ProbabilityEntry *const originalProbabilityEntry,
const ProbabilityEntry *const probabilityEntry) const { const ProbabilityEntry *const probabilityEntry) const {

View File

@ -85,6 +85,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy, mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
&shortcutPolicy); &shortcutPolicy);
int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
entryCountTable)) {
AKLOGE("Failed to update probabilities in language model dict content.");
return false;
}
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader); DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
DynamicPtGcEventListeners DynamicPtGcEventListeners

View File

@ -84,6 +84,10 @@ class TrieMap {
return mValue; return mValue;
} }
// Returns the value stored in mNextLevelBitmapEntryIndex.
// NOTE(review): presumably only meaningful when this entry actually has a next-level map;
// confirm what value is stored otherwise.
AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const {
return mNextLevelBitmapEntryIndex;
}
private: private:
const TrieMap *const mTrieMap; const TrieMap *const mTrieMap;
const int mKey; const int mKey;