Use better conditional probability for ngram entries.

Old:
P(W | W_prev) = f(W, W_prev) + C
New:
P(W | W_prev) = f(W, W_prev) / f(W_prev)

Bug: 14425059
Bug: 16547409

Change-Id: I4d13be6de2c6bad6bad7fb22320a23ba4ecd361c
main
Keisuke Kuroyanagi 2014-10-15 18:23:00 +09:00
parent 5400701908
commit 72d17d9209
2 changed files with 26 additions and 11 deletions

View File

@ -43,18 +43,18 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
const int wordId, const HeaderPolicy *const headerPolicy) const { const int wordId, const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
int maxLevel = 0; int maxPrevWordCount = 0;
for (size_t i = 0; i < prevWordIds.size(); ++i) { for (size_t i = 0; i < prevWordIds.size(); ++i) {
const int nextBitmapEntryIndex = const int nextBitmapEntryIndex =
mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex; mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) { if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
break; break;
} }
maxLevel = i + 1; maxPrevWordCount = i + 1;
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
} }
for (int i = maxLevel; i >= 0; --i) { for (int i = maxPrevWordCount; i >= 0; --i) {
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
if (!result.mIsValid) { if (!result.mIsValid) {
continue; continue;
@ -69,9 +69,24 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
// The entry should not be treated as a valid entry. // The entry should not be treated as a valid entry.
continue; continue;
} }
probability = std::min(rawProbability if (i == 0) {
+ ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */), // unigram
MAX_PROBABILITY); probability = rawProbability;
} else {
const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
if (!prevWordProbabilityEntry.isValid()) {
continue;
}
if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
probability = rawProbability;
} else {
const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
probability = std::min(MAX_PROBABILITY - prevWordRawProbability
+ rawProbability, MAX_PROBABILITY);
}
}
} else { } else {
probability = probabilityEntry.getProbability(); probability = probabilityEntry.getProbability();
} }

View File

@ -98,17 +98,17 @@ class ProbabilityEntry {
} }
uint64_t encode(const bool hasHistoricalInfo) const { uint64_t encode(const bool hasHistoricalInfo) const {
uint64_t encodedEntry = static_cast<uint64_t>(mFlags); uint64_t encodedEntry = static_cast<uint8_t>(mFlags);
if (hasHistoricalInfo) { if (hasHistoricalInfo) {
encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT)) encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mHistoricalInfo.getTimestamp()); | static_cast<uint32_t>(mHistoricalInfo.getTimestamp());
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT)) encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mHistoricalInfo.getLevel()); | static_cast<uint8_t>(mHistoricalInfo.getLevel());
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mHistoricalInfo.getCount()); | static_cast<uint8_t>(mHistoricalInfo.getCount());
} else { } else {
encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT)) encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mProbability); | static_cast<uint8_t>(mProbability);
} }
return encodedEntry; return encodedEntry;
} }