Merge "Improve bigram probability computation for decaying dicts."

This commit is contained in:
Keisuke Kuroyanagi 2014-10-06 13:06:28 +00:00 committed by Android (Google) Code Review
commit 948ef10d03
6 changed files with 83 additions and 98 deletions

View file

@ -35,23 +35,15 @@ const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE
// count. // count.
const char *const HeaderPolicy::HAS_HISTORICAL_INFO_KEY = "HAS_HISTORICAL_INFO"; const char *const HeaderPolicy::HAS_HISTORICAL_INFO_KEY = "HAS_HISTORICAL_INFO";
const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const char *const HeaderPolicy::FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY =
"FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP";
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY = const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
"FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID"; "FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
const char *const HeaderPolicy::FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY =
"FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS";
const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT"; const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT";
const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT"; const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT";
const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100; const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f; const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP = 2;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3; const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
// 30 days
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS =
30 * 24 * 60 * 60;
const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000; const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000; const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000;

View file

@ -53,15 +53,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)), EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)), &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY,
DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue( mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY, &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)), DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY,
DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)),
mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue( mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
@ -86,15 +80,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0), mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)), &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY,
DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue( mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY, &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)), DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY,
DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)),
mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue( mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
@ -113,12 +101,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount), mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount),
mExtendedRegionSize(headerPolicy->mExtendedRegionSize), mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords), mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
mForgettingCurveOccurrencesToLevelUp(
headerPolicy->mForgettingCurveOccurrencesToLevelUp),
mForgettingCurveProbabilityValuesTableId( mForgettingCurveProbabilityValuesTableId(
headerPolicy->mForgettingCurveProbabilityValuesTableId), headerPolicy->mForgettingCurveProbabilityValuesTableId),
mForgettingCurveDurationToLevelDown(
headerPolicy->mForgettingCurveDurationToLevelDown),
mMaxUnigramCount(headerPolicy->mMaxUnigramCount), mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
mMaxBigramCount(headerPolicy->mMaxBigramCount), mMaxBigramCount(headerPolicy->mMaxBigramCount),
mCodePointTable(headerPolicy->mCodePointTable) {} mCodePointTable(headerPolicy->mCodePointTable) {}
@ -130,8 +114,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false), mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0),
mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
mForgettingCurveOccurrencesToLevelUp(0), mForgettingCurveProbabilityValuesTableId(0), mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0),
mForgettingCurveDurationToLevelDown(0), mMaxUnigramCount(0), mMaxBigramCount(0),
mCodePointTable(nullptr) {} mCodePointTable(nullptr) {}
~HeaderPolicy() {} ~HeaderPolicy() {}
@ -217,18 +200,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return &mAttributeMap; return &mAttributeMap;
} }
AK_FORCE_INLINE int getForgettingCurveOccurrencesToLevelUp() const {
return mForgettingCurveOccurrencesToLevelUp;
}
AK_FORCE_INLINE int getForgettingCurveProbabilityValuesTableId() const { AK_FORCE_INLINE int getForgettingCurveProbabilityValuesTableId() const {
return mForgettingCurveProbabilityValuesTableId; return mForgettingCurveProbabilityValuesTableId;
} }
AK_FORCE_INLINE int getForgettingCurveDurationToLevelDown() const {
return mForgettingCurveDurationToLevelDown;
}
AK_FORCE_INLINE int getMaxUnigramCount() const { AK_FORCE_INLINE int getMaxUnigramCount() const {
return mMaxUnigramCount; return mMaxUnigramCount;
} }
@ -280,9 +255,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const MAX_BIGRAM_COUNT_KEY; static const char *const MAX_BIGRAM_COUNT_KEY;
static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE; static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
static const int DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP;
static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID; static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
static const int DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS;
static const int DEFAULT_MAX_UNIGRAM_COUNT; static const int DEFAULT_MAX_UNIGRAM_COUNT;
static const int DEFAULT_MAX_BIGRAM_COUNT; static const int DEFAULT_MAX_BIGRAM_COUNT;
@ -300,9 +273,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const int mBigramCount; const int mBigramCount;
const int mExtendedRegionSize; const int mExtendedRegionSize;
const bool mHasHistoricalInfoOfWords; const bool mHasHistoricalInfoOfWords;
const int mForgettingCurveOccurrencesToLevelUp;
const int mForgettingCurveProbabilityValuesTableId; const int mForgettingCurveProbabilityValuesTableId;
const int mForgettingCurveDurationToLevelDown;
const int mMaxUnigramCount; const int mMaxUnigramCount;
const int mMaxBigramCount; const int mMaxBigramCount;
const int *const mCodePointTable; const int *const mCodePointTable;

View file

@ -146,18 +146,15 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probabi
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
const int bigramProbability) const { const int bigramProbability) const {
if (mHeaderPolicy->isDecayingDict()) { // In the v4 format, bigramProbability is a conditional probability.
// Both probabilities are encoded. Decode them and get probability. const int bigramConditionalProbability = bigramProbability;
return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability); if (unigramProbability == NOT_A_PROBABILITY) {
} else { return NOT_A_PROBABILITY;
if (unigramProbability == NOT_A_PROBABILITY) {
return NOT_A_PROBABILITY;
} else if (bigramProbability == NOT_A_PROBABILITY) {
return ProbabilityUtils::backoff(unigramProbability);
} else {
return bigramProbability;
}
} }
if (bigramConditionalProbability == NOT_A_PROBABILITY) {
return ProbabilityUtils::backoff(unigramProbability);
}
return bigramConditionalProbability;
} }
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@ -170,37 +167,66 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
if (!prevWordIds.empty()) { if (prevWordIds.empty()) {
const int bigramsPosition = getBigramsPositionOfPtNode( return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
getTerminalPtNodePosFromWordId(prevWordIds[0])); }
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); if (prevWordIds[0] == NOT_A_WORD_ID) {
while (bigramsIt.hasNext()) {
bigramsIt.next();
if (bigramsIt.getBigramPos() == ptNodePos
&& bigramsIt.getProbability() != NOT_A_PROBABILITY) {
return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability());
}
}
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); const PtNodeParams prevWordPtNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]);
if (prevWordPtNodeParams.isDeleted()) {
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos(
prevWordPtNodeParams.getTerminalId());
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
while (bigramsIt.hasNext()) {
bigramsIt.next();
if (bigramsIt.getBigramPos() == ptNodePos
&& bigramsIt.getProbability() != NOT_A_PROBABILITY) {
const int bigramConditionalProbability = getBigramConditionalProbability(
prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
return getProbability(ptNodeParams.getProbability(), bigramConditionalProbability);
}
}
return NOT_A_PROBABILITY;
} }
void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
NgramListener *const listener) const { NgramListener *const listener) const {
if (prevWordIds.empty()) { if (prevWordIds.firstOrDefault(NOT_A_DICT_POS) == NOT_A_DICT_POS) {
return; return;
} }
const int bigramsPosition = getBigramsPositionOfPtNode( const PtNodeParams prevWordPtNodeParams =
getTerminalPtNodePosFromWordId(prevWordIds[0])); mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]);
if (prevWordPtNodeParams.isDeleted()) {
return;
}
const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos(
prevWordPtNodeParams.getTerminalId());
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
while (bigramsIt.hasNext()) { while (bigramsIt.hasNext()) {
bigramsIt.next(); bigramsIt.next();
listener->onVisitEntry(bigramsIt.getProbability(), const int bigramConditionalProbability = getBigramConditionalProbability(
prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
listener->onVisitEntry(bigramConditionalProbability,
getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos())); getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()));
} }
} }
int Ver4PatriciaTriePolicy::getBigramConditionalProbability(const int prevWordUnigramProbability,
const int bigramProbability) const {
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
// Calculate conditional probability.
return std::min(MAX_PROBABILITY - prevWordUnigramProbability + bigramProbability,
MAX_PROBABILITY);
} else {
// bigramProbability is a conditional probability.
return bigramProbability;
}
}
BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
const int wordId) const { const int wordId) const {
const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId)); const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId));

View file

@ -174,6 +174,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getTerminalPtNodePosFromWordId(const int wordId) const; int getTerminalPtNodePosFromWordId(const int wordId) const;
const WordAttributes getWordAttributes(const int probability, const WordAttributes getWordAttributes(const int probability,
const PtNodeParams &ptNodeParams) const; const PtNodeParams &ptNodeParams) const;
int getBigramConditionalProbability(const int prevWordUnigramProbability,
const int bigramProbability) const;
}; };
} // namespace v402 } // namespace v402
} // namespace backward } // namespace backward

View file

@ -29,10 +29,14 @@ namespace latinime {
const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8; const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8;
const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60; const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60;
const int ForgettingCurveUtils::MAX_LEVEL = 3; const int ForgettingCurveUtils::MAX_LEVEL = 15;
const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1; const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 2;
const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15; const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 31;
const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14; const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 30;
const int ForgettingCurveUtils::OCCURRENCES_TO_RAISE_THE_LEVEL = 1;
// TODO: Evaluate whether this should be 7.5 days.
// 15 days
const int ForgettingCurveUtils::DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS = 15 * 24 * 60 * 60;
const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2; const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2; const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
@ -54,19 +58,23 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
|| (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel() || (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel()
&& originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) { && originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) {
// Initial information. // Initial information.
int count = newHistoricalInfo->getCount();
if (count >= OCCURRENCES_TO_RAISE_THE_LEVEL) {
const int level = clampToValidLevelRange(newHistoricalInfo->getLevel() + 1);
return HistoricalInfo(timestamp, level, 0 /* count */);
}
const int level = clampToValidLevelRange(newHistoricalInfo->getLevel()); const int level = clampToValidLevelRange(newHistoricalInfo->getLevel());
const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy); return HistoricalInfo(timestamp, level, clampToValidCountRange(count, headerPolicy));
return HistoricalInfo(timestamp, level, count);
} else { } else {
const int updatedCount = originalHistoricalInfo->getCount() + 1; const int updatedCount = originalHistoricalInfo->getCount() + 1;
if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) { if (updatedCount >= OCCURRENCES_TO_RAISE_THE_LEVEL) {
// The count exceeds the max value the level can be incremented. // The count exceeds the max value the level can be incremented.
if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) { if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) {
// The level is already max. // The level is already max.
return HistoricalInfo(timestamp, return HistoricalInfo(timestamp,
originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount()); originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount());
} else { } else {
// Level up. // Raise the level.
return HistoricalInfo(timestamp, return HistoricalInfo(timestamp,
originalHistoricalInfo->getLevel() + 1, 0 /* count */); originalHistoricalInfo->getLevel() + 1, 0 /* count */);
} }
@ -79,31 +87,18 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
/* static */ int ForgettingCurveUtils::decodeProbability( /* static */ int ForgettingCurveUtils::decodeProbability(
const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy) { const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy) {
const int elapsedTimeStepCount = getElapsedTimeStepCount(historicalInfo->getTimestamp(), const int elapsedTimeStepCount = getElapsedTimeStepCount(historicalInfo->getTimestamp(),
headerPolicy->getForgettingCurveDurationToLevelDown()); DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS);
return sProbabilityTable.getProbability( return sProbabilityTable.getProbability(
headerPolicy->getForgettingCurveProbabilityValuesTableId(), headerPolicy->getForgettingCurveProbabilityValuesTableId(),
clampToValidLevelRange(historicalInfo->getLevel()), clampToValidLevelRange(historicalInfo->getLevel()),
clampToValidTimeStepCountRange(elapsedTimeStepCount)); clampToValidTimeStepCountRange(elapsedTimeStepCount));
} }
/* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability,
const int bigramProbability) {
if (unigramProbability == NOT_A_PROBABILITY) {
return NOT_A_PROBABILITY;
} else if (bigramProbability == NOT_A_PROBABILITY) {
return std::min(backoff(unigramProbability), MAX_PROBABILITY);
} else {
// TODO: Investigate better way to handle bigram probability.
return std::min(std::max(unigramProbability,
bigramProbability + MULTIPLIER_TWO_IN_PROBABILITY_SCALE), MAX_PROBABILITY);
}
}
/* static */ bool ForgettingCurveUtils::needsToKeep(const HistoricalInfo *const historicalInfo, /* static */ bool ForgettingCurveUtils::needsToKeep(const HistoricalInfo *const historicalInfo,
const HeaderPolicy *const headerPolicy) { const HeaderPolicy *const headerPolicy) {
return historicalInfo->getLevel() > 0 return historicalInfo->getLevel() > 0
|| getElapsedTimeStepCount(historicalInfo->getTimestamp(), || getElapsedTimeStepCount(historicalInfo->getTimestamp(),
headerPolicy->getForgettingCurveDurationToLevelDown()) DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS)
< DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD; < DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
} }
@ -113,14 +108,14 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
if (originalHistoricalInfo->getTimestamp() == NOT_A_TIMESTAMP) { if (originalHistoricalInfo->getTimestamp() == NOT_A_TIMESTAMP) {
return HistoricalInfo(); return HistoricalInfo();
} }
const int durationToLevelDownInSeconds = headerPolicy->getForgettingCurveDurationToLevelDown(); const int durationToLevelDownInSeconds = DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS;
const int elapsedTimeStep = getElapsedTimeStepCount( const int elapsedTimeStep = getElapsedTimeStepCount(
originalHistoricalInfo->getTimestamp(), durationToLevelDownInSeconds); originalHistoricalInfo->getTimestamp(), durationToLevelDownInSeconds);
if (elapsedTimeStep <= MAX_ELAPSED_TIME_STEP_COUNT) { if (elapsedTimeStep <= MAX_ELAPSED_TIME_STEP_COUNT) {
// No need to update historical info. // No need to update historical info.
return *originalHistoricalInfo; return *originalHistoricalInfo;
} }
// Level down. // Lower the level.
const int maxLevelDownAmonut = elapsedTimeStep / (MAX_ELAPSED_TIME_STEP_COUNT + 1); const int maxLevelDownAmonut = elapsedTimeStep / (MAX_ELAPSED_TIME_STEP_COUNT + 1);
const int levelDownAmount = (maxLevelDownAmonut >= originalHistoricalInfo->getLevel()) ? const int levelDownAmount = (maxLevelDownAmonut >= originalHistoricalInfo->getLevel()) ?
originalHistoricalInfo->getLevel() : maxLevelDownAmonut; originalHistoricalInfo->getLevel() : maxLevelDownAmonut;
@ -170,7 +165,7 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
/* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count, /* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count,
const HeaderPolicy *const headerPolicy) { const HeaderPolicy *const headerPolicy) {
return std::min(std::max(count, 0), headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1); return std::min(std::max(count, 0), OCCURRENCES_TO_RAISE_THE_LEVEL - 1);
} }
/* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) { /* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) {
@ -187,9 +182,9 @@ const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID =
const int ForgettingCurveUtils::ProbabilityTable::STRONG_PROBABILITY_TABLE_ID = 2; const int ForgettingCurveUtils::ProbabilityTable::STRONG_PROBABILITY_TABLE_ID = 2;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_PROBABILITY_TABLE_ID = 3; const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_PROBABILITY_TABLE_ID = 3;
const int ForgettingCurveUtils::ProbabilityTable::WEAK_MAX_PROBABILITY = 127; const int ForgettingCurveUtils::ProbabilityTable::WEAK_MAX_PROBABILITY = 127;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 32; const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 8;
const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 35; const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 9;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 40; const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 10;
ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() { ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
@ -202,7 +197,7 @@ ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
const float endProbability = getBaseProbabilityForLevel(tableId, level - 1); const float endProbability = getBaseProbabilityForLevel(tableId, level - 1);
for (int timeStepCount = 0; timeStepCount <= MAX_ELAPSED_TIME_STEP_COUNT; for (int timeStepCount = 0; timeStepCount <= MAX_ELAPSED_TIME_STEP_COUNT;
++timeStepCount) { ++timeStepCount) {
if (level == 0) { if (level < MIN_VISIBLE_LEVEL) {
mTables[tableId][level][timeStepCount] = NOT_A_PROBABILITY; mTables[tableId][level][timeStepCount] = NOT_A_PROBABILITY;
continue; continue;
} }

View file

@ -39,9 +39,6 @@ class ForgettingCurveUtils {
static int decodeProbability(const HistoricalInfo *const historicalInfo, static int decodeProbability(const HistoricalInfo *const historicalInfo,
const HeaderPolicy *const headerPolicy); const HeaderPolicy *const headerPolicy);
static int getProbability(const int encodedUnigramProbability,
const int encodedBigramProbability);
static bool needsToKeep(const HistoricalInfo *const historicalInfo, static bool needsToKeep(const HistoricalInfo *const historicalInfo,
const HeaderPolicy *const headerPolicy); const HeaderPolicy *const headerPolicy);
@ -101,6 +98,8 @@ class ForgettingCurveUtils {
static const int MIN_VISIBLE_LEVEL; static const int MIN_VISIBLE_LEVEL;
static const int MAX_ELAPSED_TIME_STEP_COUNT; static const int MAX_ELAPSED_TIME_STEP_COUNT;
static const int DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD; static const int DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
static const int OCCURRENCES_TO_RAISE_THE_LEVEL;
static const int DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS;
static const float UNIGRAM_COUNT_HARD_LIMIT_WEIGHT; static const float UNIGRAM_COUNT_HARD_LIMIT_WEIGHT;
static const float BIGRAM_COUNT_HARD_LIMIT_WEIGHT; static const float BIGRAM_COUNT_HARD_LIMIT_WEIGHT;