Merge "Use better conditional probability for ngram entries."

This commit is contained in:
Keisuke Kuroyanagi 2014-10-15 09:27:20 +00:00 committed by Android (Google) Code Review
commit 183e21c36c
2 changed files with 26 additions and 11 deletions

View file

@ -43,18 +43,18 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
const int wordId, const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
int maxLevel = 0;
int maxPrevWordCount = 0;
for (size_t i = 0; i < prevWordIds.size(); ++i) {
const int nextBitmapEntryIndex =
mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
break;
}
maxLevel = i + 1;
maxPrevWordCount = i + 1;
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
}
for (int i = maxLevel; i >= 0; --i) {
for (int i = maxPrevWordCount; i >= 0; --i) {
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
if (!result.mIsValid) {
continue;
@ -69,9 +69,24 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
// The entry should not be treated as a valid entry.
continue;
}
probability = std::min(rawProbability
+ ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
MAX_PROBABILITY);
if (i == 0) {
// unigram
probability = rawProbability;
} else {
const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
if (!prevWordProbabilityEntry.isValid()) {
continue;
}
if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
probability = rawProbability;
} else {
const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
probability = std::min(MAX_PROBABILITY - prevWordRawProbability
+ rawProbability, MAX_PROBABILITY);
}
}
} else {
probability = probabilityEntry.getProbability();
}

View file

@ -98,17 +98,17 @@ class ProbabilityEntry {
}
uint64_t encode(const bool hasHistoricalInfo) const {
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
uint64_t encodedEntry = static_cast<uint8_t>(mFlags);
if (hasHistoricalInfo) {
encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mHistoricalInfo.getTimestamp());
| static_cast<uint32_t>(mHistoricalInfo.getTimestamp());
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mHistoricalInfo.getLevel());
| static_cast<uint8_t>(mHistoricalInfo.getLevel());
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mHistoricalInfo.getCount());
| static_cast<uint8_t>(mHistoricalInfo.getCount());
} else {
encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mProbability);
| static_cast<uint8_t>(mProbability);
}
return encodedEntry;
}