am 094a8a68: Merge "Update probabilities in language model dict content for GC."
* commit '094a8a68e3ec4d7d5cacd9be08fa4558b096a93f': Update probabilities in language model dict content for GC.main
commit
64bf62e25e
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
|
||||
|
||||
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
|
||||
|
||||
namespace latinime {
|
||||
|
||||
bool LanguageModelDictContent::save(FILE *const file) const {
|
||||
|
@ -118,4 +120,46 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
|
|||
return bitmapEntryIndex;
|
||||
}
|
||||
|
||||
bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
|
||||
const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
|
||||
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
|
||||
if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
|
||||
AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
|
||||
level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
|
||||
return false;
|
||||
}
|
||||
const ProbabilityEntry probabilityEntry =
|
||||
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
|
||||
if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
|
||||
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
|
||||
probabilityEntry.getHistoricalInfo(), headerPolicy);
|
||||
if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
|
||||
// Update the entry.
|
||||
const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
|
||||
if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
|
||||
bitmapEntryIndex)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// Remove the entry.
|
||||
if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (!probabilityEntry.representsBeginningOfSentence()) {
|
||||
outEntryCounts[level] += 1;
|
||||
}
|
||||
if (!entry.hasNextLevelMap()) {
|
||||
continue;
|
||||
}
|
||||
if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
|
||||
headerPolicy, outEntryCounts)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace latinime
|
||||
|
|
|
@ -29,6 +29,8 @@
|
|||
|
||||
namespace latinime {
|
||||
|
||||
class HeaderPolicy;
|
||||
|
||||
/**
|
||||
* Class representing language model.
|
||||
*
|
||||
|
@ -73,6 +75,12 @@ class LanguageModelDictContent {
|
|||
|
||||
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
|
||||
|
||||
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
|
||||
int *const outEntryCounts) {
|
||||
return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
|
||||
headerPolicy, outEntryCounts);
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
|
||||
|
||||
|
@ -84,6 +92,8 @@ class LanguageModelDictContent {
|
|||
int *const outNgramCount);
|
||||
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
|
||||
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
|
||||
bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
|
||||
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
|
||||
};
|
||||
} // namespace latinime
|
||||
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
|
||||
|
|
|
@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
|
|||
const ProbabilityEntry originalProbabilityEntry =
|
||||
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
|
||||
toBeUpdatedPtNodeParams->getTerminalId());
|
||||
if (originalProbabilityEntry.hasHistoricalInfo()) {
|
||||
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
|
||||
originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
|
||||
const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
|
||||
&historicalInfo);
|
||||
if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
|
||||
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
|
||||
AKLOGE("Cannot write updated probability entry. terminalId: %d",
|
||||
toBeUpdatedPtNodeParams->getTerminalId());
|
||||
return false;
|
||||
if (originalProbabilityEntry.isValid()) {
|
||||
*outNeedsToKeepPtNode = true;
|
||||
return true;
|
||||
}
|
||||
const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
|
||||
if (!isValid) {
|
||||
if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
|
||||
AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
*outNeedsToKeepPtNode = isValid;
|
||||
} else {
|
||||
// No need to update probability.
|
||||
*outNeedsToKeepPtNode = true;
|
||||
}
|
||||
*outNeedsToKeepPtNode = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
|
|||
isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
|
||||
}
|
||||
|
||||
// TODO: Move probability handling code to LanguageModelDictContent.
|
||||
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
|
||||
const ProbabilityEntry *const originalProbabilityEntry,
|
||||
const ProbabilityEntry *const probabilityEntry) const {
|
||||
|
|
|
@ -85,6 +85,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
|
|||
mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
|
||||
&shortcutPolicy);
|
||||
|
||||
int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
|
||||
entryCountTable)) {
|
||||
AKLOGE("Failed to update probabilities in language model dict content.");
|
||||
return false;
|
||||
}
|
||||
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
|
||||
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
|
||||
DynamicPtGcEventListeners
|
||||
|
|
|
@ -84,6 +84,10 @@ class TrieMap {
|
|||
return mValue;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const {
|
||||
return mNextLevelBitmapEntryIndex;
|
||||
}
|
||||
|
||||
private:
|
||||
const TrieMap *const mTrieMap;
|
||||
const int mKey;
|
||||
|
|
Loading…
Reference in New Issue