Merge "Improve bigram probability computation for decaying dicts."

This commit is contained in:
Keisuke Kuroyanagi 2014-10-06 13:06:28 +00:00 committed by Android (Google) Code Review
commit 948ef10d03
6 changed files with 83 additions and 98 deletions

View file

@ -35,23 +35,15 @@ const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE
// count.
const char *const HeaderPolicy::HAS_HISTORICAL_INFO_KEY = "HAS_HISTORICAL_INFO";
const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const char *const HeaderPolicy::FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY =
"FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP";
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
"FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
const char *const HeaderPolicy::FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY =
"FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS";
const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT";
const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT";
const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP = 2;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
// 30 days
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS =
30 * 24 * 60 * 60;
const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000;

View file

@ -53,15 +53,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY,
DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY,
DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)),
mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
@ -86,15 +80,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY,
DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY,
DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)),
mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
@ -113,12 +101,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount),
mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
mForgettingCurveOccurrencesToLevelUp(
headerPolicy->mForgettingCurveOccurrencesToLevelUp),
mForgettingCurveProbabilityValuesTableId(
headerPolicy->mForgettingCurveProbabilityValuesTableId),
mForgettingCurveDurationToLevelDown(
headerPolicy->mForgettingCurveDurationToLevelDown),
mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
mMaxBigramCount(headerPolicy->mMaxBigramCount),
mCodePointTable(headerPolicy->mCodePointTable) {}
@ -130,8 +114,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0),
mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
mForgettingCurveOccurrencesToLevelUp(0), mForgettingCurveProbabilityValuesTableId(0),
mForgettingCurveDurationToLevelDown(0), mMaxUnigramCount(0), mMaxBigramCount(0),
mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0),
mCodePointTable(nullptr) {}
~HeaderPolicy() {}
@ -217,18 +200,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return &mAttributeMap;
}
AK_FORCE_INLINE int getForgettingCurveOccurrencesToLevelUp() const {
return mForgettingCurveOccurrencesToLevelUp;
}
AK_FORCE_INLINE int getForgettingCurveProbabilityValuesTableId() const {
return mForgettingCurveProbabilityValuesTableId;
}
AK_FORCE_INLINE int getForgettingCurveDurationToLevelDown() const {
return mForgettingCurveDurationToLevelDown;
}
AK_FORCE_INLINE int getMaxUnigramCount() const {
return mMaxUnigramCount;
}
@ -280,9 +255,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const MAX_BIGRAM_COUNT_KEY;
static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
static const int DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP;
static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
static const int DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS;
static const int DEFAULT_MAX_UNIGRAM_COUNT;
static const int DEFAULT_MAX_BIGRAM_COUNT;
@ -300,9 +273,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const int mBigramCount;
const int mExtendedRegionSize;
const bool mHasHistoricalInfoOfWords;
const int mForgettingCurveOccurrencesToLevelUp;
const int mForgettingCurveProbabilityValuesTableId;
const int mForgettingCurveDurationToLevelDown;
const int mMaxUnigramCount;
const int mMaxBigramCount;
const int *const mCodePointTable;

View file

@ -146,18 +146,15 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probabi
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
const int bigramProbability) const {
if (mHeaderPolicy->isDecayingDict()) {
// Both probabilities are encoded. Decode them and get probability.
return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability);
} else {
// In the v4 format, bigramProbability is a conditional probability.
const int bigramConditionalProbability = bigramProbability;
if (unigramProbability == NOT_A_PROBABILITY) {
return NOT_A_PROBABILITY;
} else if (bigramProbability == NOT_A_PROBABILITY) {
}
if (bigramConditionalProbability == NOT_A_PROBABILITY) {
return ProbabilityUtils::backoff(unigramProbability);
} else {
return bigramProbability;
}
}
return bigramConditionalProbability;
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@ -170,37 +167,66 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
if (!prevWordIds.empty()) {
const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
if (prevWordIds.empty()) {
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
if (prevWordIds[0] == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY;
}
const PtNodeParams prevWordPtNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]);
if (prevWordPtNodeParams.isDeleted()) {
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos(
prevWordPtNodeParams.getTerminalId());
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
while (bigramsIt.hasNext()) {
bigramsIt.next();
if (bigramsIt.getBigramPos() == ptNodePos
&& bigramsIt.getProbability() != NOT_A_PROBABILITY) {
return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability());
const int bigramConditionalProbability = getBigramConditionalProbability(
prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
return getProbability(ptNodeParams.getProbability(), bigramConditionalProbability);
}
}
return NOT_A_PROBABILITY;
}
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
NgramListener *const listener) const {
if (prevWordIds.empty()) {
if (prevWordIds.firstOrDefault(NOT_A_DICT_POS) == NOT_A_DICT_POS) {
return;
}
const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
const PtNodeParams prevWordPtNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]);
if (prevWordPtNodeParams.isDeleted()) {
return;
}
const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos(
prevWordPtNodeParams.getTerminalId());
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
while (bigramsIt.hasNext()) {
bigramsIt.next();
listener->onVisitEntry(bigramsIt.getProbability(),
const int bigramConditionalProbability = getBigramConditionalProbability(
prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
listener->onVisitEntry(bigramConditionalProbability,
getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()));
}
}
int Ver4PatriciaTriePolicy::getBigramConditionalProbability(const int prevWordUnigramProbability,
const int bigramProbability) const {
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
// Calculate conditional probability.
return std::min(MAX_PROBABILITY - prevWordUnigramProbability + bigramProbability,
MAX_PROBABILITY);
} else {
// bigramProbability is a conditional probability.
return bigramProbability;
}
}
BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
const int wordId) const {
const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId));

View file

@ -174,6 +174,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getTerminalPtNodePosFromWordId(const int wordId) const;
const WordAttributes getWordAttributes(const int probability,
const PtNodeParams &ptNodeParams) const;
int getBigramConditionalProbability(const int prevWordUnigramProbability,
const int bigramProbability) const;
};
} // namespace v402
} // namespace backward

View file

@ -29,10 +29,14 @@ namespace latinime {
const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8;
const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60;
const int ForgettingCurveUtils::MAX_LEVEL = 3;
const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1;
const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15;
const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14;
const int ForgettingCurveUtils::MAX_LEVEL = 15;
const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 2;
const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 31;
const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 30;
const int ForgettingCurveUtils::OCCURRENCES_TO_RAISE_THE_LEVEL = 1;
// TODO: Evaluate whether this should be 7.5 days.
// 15 days
const int ForgettingCurveUtils::DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS = 15 * 24 * 60 * 60;
const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
@ -54,19 +58,23 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
|| (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel()
&& originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) {
// Initial information.
int count = newHistoricalInfo->getCount();
if (count >= OCCURRENCES_TO_RAISE_THE_LEVEL) {
const int level = clampToValidLevelRange(newHistoricalInfo->getLevel() + 1);
return HistoricalInfo(timestamp, level, 0 /* count */);
}
const int level = clampToValidLevelRange(newHistoricalInfo->getLevel());
const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
return HistoricalInfo(timestamp, level, count);
return HistoricalInfo(timestamp, level, clampToValidCountRange(count, headerPolicy));
} else {
const int updatedCount = originalHistoricalInfo->getCount() + 1;
if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) {
if (updatedCount >= OCCURRENCES_TO_RAISE_THE_LEVEL) {
// The count exceeds the max value the level can be incremented.
if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) {
// The level is already max.
return HistoricalInfo(timestamp,
originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount());
} else {
// Level up.
// Raise the level.
return HistoricalInfo(timestamp,
originalHistoricalInfo->getLevel() + 1, 0 /* count */);
}
@ -79,31 +87,18 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
/* static */ int ForgettingCurveUtils::decodeProbability(
const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy) {
const int elapsedTimeStepCount = getElapsedTimeStepCount(historicalInfo->getTimestamp(),
headerPolicy->getForgettingCurveDurationToLevelDown());
DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS);
return sProbabilityTable.getProbability(
headerPolicy->getForgettingCurveProbabilityValuesTableId(),
clampToValidLevelRange(historicalInfo->getLevel()),
clampToValidTimeStepCountRange(elapsedTimeStepCount));
}
/* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability,
const int bigramProbability) {
if (unigramProbability == NOT_A_PROBABILITY) {
return NOT_A_PROBABILITY;
} else if (bigramProbability == NOT_A_PROBABILITY) {
return std::min(backoff(unigramProbability), MAX_PROBABILITY);
} else {
// TODO: Investigate better way to handle bigram probability.
return std::min(std::max(unigramProbability,
bigramProbability + MULTIPLIER_TWO_IN_PROBABILITY_SCALE), MAX_PROBABILITY);
}
}
/* static */ bool ForgettingCurveUtils::needsToKeep(const HistoricalInfo *const historicalInfo,
const HeaderPolicy *const headerPolicy) {
return historicalInfo->getLevel() > 0
|| getElapsedTimeStepCount(historicalInfo->getTimestamp(),
headerPolicy->getForgettingCurveDurationToLevelDown())
DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS)
< DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
}
@ -113,14 +108,14 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
if (originalHistoricalInfo->getTimestamp() == NOT_A_TIMESTAMP) {
return HistoricalInfo();
}
const int durationToLevelDownInSeconds = headerPolicy->getForgettingCurveDurationToLevelDown();
const int durationToLevelDownInSeconds = DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS;
const int elapsedTimeStep = getElapsedTimeStepCount(
originalHistoricalInfo->getTimestamp(), durationToLevelDownInSeconds);
if (elapsedTimeStep <= MAX_ELAPSED_TIME_STEP_COUNT) {
// No need to update historical info.
return *originalHistoricalInfo;
}
// Level down.
// Lower the level.
const int maxLevelDownAmonut = elapsedTimeStep / (MAX_ELAPSED_TIME_STEP_COUNT + 1);
const int levelDownAmount = (maxLevelDownAmonut >= originalHistoricalInfo->getLevel()) ?
originalHistoricalInfo->getLevel() : maxLevelDownAmonut;
@ -170,7 +165,7 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
/* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count,
const HeaderPolicy *const headerPolicy) {
return std::min(std::max(count, 0), headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1);
return std::min(std::max(count, 0), OCCURRENCES_TO_RAISE_THE_LEVEL - 1);
}
/* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) {
@ -187,9 +182,9 @@ const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID =
const int ForgettingCurveUtils::ProbabilityTable::STRONG_PROBABILITY_TABLE_ID = 2;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_PROBABILITY_TABLE_ID = 3;
const int ForgettingCurveUtils::ProbabilityTable::WEAK_MAX_PROBABILITY = 127;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 32;
const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 35;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 40;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 8;
const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 9;
const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 10;
ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
@ -202,7 +197,7 @@ ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
const float endProbability = getBaseProbabilityForLevel(tableId, level - 1);
for (int timeStepCount = 0; timeStepCount <= MAX_ELAPSED_TIME_STEP_COUNT;
++timeStepCount) {
if (level == 0) {
if (level < MIN_VISIBLE_LEVEL) {
mTables[tableId][level][timeStepCount] = NOT_A_PROBABILITY;
continue;
}

View file

@ -39,9 +39,6 @@ class ForgettingCurveUtils {
static int decodeProbability(const HistoricalInfo *const historicalInfo,
const HeaderPolicy *const headerPolicy);
static int getProbability(const int encodedUnigramProbability,
const int encodedBigramProbability);
static bool needsToKeep(const HistoricalInfo *const historicalInfo,
const HeaderPolicy *const headerPolicy);
@ -101,6 +98,8 @@ class ForgettingCurveUtils {
static const int MIN_VISIBLE_LEVEL;
static const int MAX_ELAPSED_TIME_STEP_COUNT;
static const int DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
static const int OCCURRENCES_TO_RAISE_THE_LEVEL;
static const int DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS;
static const float UNIGRAM_COUNT_HARD_LIMIT_WEIGHT;
static const float BIGRAM_COUNT_HARD_LIMIT_WEIGHT;