Merge "Use better conditional probability for ngram entries."
This commit is contained in:
commit
183e21c36c
2 changed files with 26 additions and 11 deletions
|
@@ -43,18 +43,18 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
|
||||||
const int wordId, const HeaderPolicy *const headerPolicy) const {
|
const int wordId, const HeaderPolicy *const headerPolicy) const {
|
||||||
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||||
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
|
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
|
||||||
int maxLevel = 0;
|
int maxPrevWordCount = 0;
|
||||||
for (size_t i = 0; i < prevWordIds.size(); ++i) {
|
for (size_t i = 0; i < prevWordIds.size(); ++i) {
|
||||||
const int nextBitmapEntryIndex =
|
const int nextBitmapEntryIndex =
|
||||||
mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
|
mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
|
||||||
if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
maxLevel = i + 1;
|
maxPrevWordCount = i + 1;
|
||||||
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
|
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = maxLevel; i >= 0; --i) {
|
for (int i = maxPrevWordCount; i >= 0; --i) {
|
||||||
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
|
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
|
||||||
if (!result.mIsValid) {
|
if (!result.mIsValid) {
|
||||||
continue;
|
continue;
|
||||||
|
@@ -69,9 +69,24 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
|
||||||
// The entry should not be treated as a valid entry.
|
// The entry should not be treated as a valid entry.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
probability = std::min(rawProbability
|
if (i == 0) {
|
||||||
+ ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
|
// unigram
|
||||||
MAX_PROBABILITY);
|
probability = rawProbability;
|
||||||
|
} else {
|
||||||
|
const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
|
||||||
|
prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
|
||||||
|
if (!prevWordProbabilityEntry.isValid()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
|
||||||
|
probability = rawProbability;
|
||||||
|
} else {
|
||||||
|
const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
|
||||||
|
prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
|
||||||
|
probability = std::min(MAX_PROBABILITY - prevWordRawProbability
|
||||||
|
+ rawProbability, MAX_PROBABILITY);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
probability = probabilityEntry.getProbability();
|
probability = probabilityEntry.getProbability();
|
||||||
}
|
}
|
||||||
|
|
|
@@ -98,17 +98,17 @@ class ProbabilityEntry {
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t encode(const bool hasHistoricalInfo) const {
|
uint64_t encode(const bool hasHistoricalInfo) const {
|
||||||
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
|
uint64_t encodedEntry = static_cast<uint8_t>(mFlags);
|
||||||
if (hasHistoricalInfo) {
|
if (hasHistoricalInfo) {
|
||||||
encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
|
encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
|
||||||
^ static_cast<uint64_t>(mHistoricalInfo.getTimestamp());
|
| static_cast<uint32_t>(mHistoricalInfo.getTimestamp());
|
||||||
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
|
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
|
||||||
^ static_cast<uint64_t>(mHistoricalInfo.getLevel());
|
| static_cast<uint8_t>(mHistoricalInfo.getLevel());
|
||||||
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
|
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
|
||||||
^ static_cast<uint64_t>(mHistoricalInfo.getCount());
|
| static_cast<uint8_t>(mHistoricalInfo.getCount());
|
||||||
} else {
|
} else {
|
||||||
encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
|
encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
|
||||||
^ static_cast<uint64_t>(mProbability);
|
| static_cast<uint8_t>(mProbability);
|
||||||
}
|
}
|
||||||
return encodedEntry;
|
return encodedEntry;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue