Merge "Update probabilities in language model dict content for GC."
This commit is contained in:
commit
094a8a68e3
5 changed files with 72 additions and 21 deletions
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
|
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
|
||||||
|
|
||||||
|
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
bool LanguageModelDictContent::save(FILE *const file) const {
|
bool LanguageModelDictContent::save(FILE *const file) const {
|
||||||
|
@ -118,4 +120,46 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
|
||||||
return bitmapEntryIndex;
|
return bitmapEntryIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
|
||||||
|
const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
|
||||||
|
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
|
||||||
|
if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
|
||||||
|
AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
|
||||||
|
level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const ProbabilityEntry probabilityEntry =
|
||||||
|
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
|
||||||
|
if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
|
||||||
|
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
|
||||||
|
probabilityEntry.getHistoricalInfo(), headerPolicy);
|
||||||
|
if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
|
||||||
|
// Update the entry.
|
||||||
|
const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
|
||||||
|
if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
|
||||||
|
bitmapEntryIndex)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Remove the entry.
|
||||||
|
if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!probabilityEntry.representsBeginningOfSentence()) {
|
||||||
|
outEntryCounts[level] += 1;
|
||||||
|
}
|
||||||
|
if (!entry.hasNextLevelMap()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
|
||||||
|
headerPolicy, outEntryCounts)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
|
|
|
@ -29,6 +29,8 @@
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
|
class HeaderPolicy;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Class representing language model.
|
* Class representing language model.
|
||||||
*
|
*
|
||||||
|
@ -73,6 +75,12 @@ class LanguageModelDictContent {
|
||||||
|
|
||||||
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
|
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
|
||||||
|
|
||||||
|
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
|
||||||
|
int *const outEntryCounts) {
|
||||||
|
return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
|
||||||
|
headerPolicy, outEntryCounts);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
|
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
|
||||||
|
|
||||||
|
@ -84,6 +92,8 @@ class LanguageModelDictContent {
|
||||||
int *const outNgramCount);
|
int *const outNgramCount);
|
||||||
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
|
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
|
||||||
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
|
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
|
||||||
|
bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
|
||||||
|
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
|
||||||
};
|
};
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
|
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
|
||||||
|
|
|
@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
|
||||||
const ProbabilityEntry originalProbabilityEntry =
|
const ProbabilityEntry originalProbabilityEntry =
|
||||||
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
|
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
|
||||||
toBeUpdatedPtNodeParams->getTerminalId());
|
toBeUpdatedPtNodeParams->getTerminalId());
|
||||||
if (originalProbabilityEntry.hasHistoricalInfo()) {
|
if (originalProbabilityEntry.isValid()) {
|
||||||
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
|
*outNeedsToKeepPtNode = true;
|
||||||
originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
|
return true;
|
||||||
const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
|
|
||||||
&historicalInfo);
|
|
||||||
if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
|
|
||||||
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
|
|
||||||
AKLOGE("Cannot write updated probability entry. terminalId: %d",
|
|
||||||
toBeUpdatedPtNodeParams->getTerminalId());
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
|
|
||||||
if (!isValid) {
|
|
||||||
if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
|
if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
|
||||||
AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
|
AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
*outNeedsToKeepPtNode = false;
|
||||||
*outNeedsToKeepPtNode = isValid;
|
|
||||||
} else {
|
|
||||||
// No need to update probability.
|
|
||||||
*outNeedsToKeepPtNode = true;
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
|
||||||
isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
|
isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Move probability handling code to LanguageModelDictContent.
|
||||||
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
|
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
|
||||||
const ProbabilityEntry *const originalProbabilityEntry,
|
const ProbabilityEntry *const originalProbabilityEntry,
|
||||||
const ProbabilityEntry *const probabilityEntry) const {
|
const ProbabilityEntry *const probabilityEntry) const {
|
||||||
|
|
|
@ -85,6 +85,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
|
||||||
mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
|
mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
|
||||||
&shortcutPolicy);
|
&shortcutPolicy);
|
||||||
|
|
||||||
|
int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||||
|
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
|
||||||
|
entryCountTable)) {
|
||||||
|
AKLOGE("Failed to update probabilities in language model dict content.");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
|
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
|
||||||
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
|
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
|
||||||
DynamicPtGcEventListeners
|
DynamicPtGcEventListeners
|
||||||
|
|
|
@ -84,6 +84,10 @@ class TrieMap {
|
||||||
return mValue;
|
return mValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const {
|
||||||
|
return mNextLevelBitmapEntryIndex;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const TrieMap *const mTrieMap;
|
const TrieMap *const mTrieMap;
|
||||||
const int mKey;
|
const int mKey;
|
||||||
|
|
Loading…
Reference in a new issue