am 183e21c3
: Merge "Use better conditional probability for ngram entries."
* commit '183e21c36cd5e05852733508bef317290e5e51ce': Use better conditional probability for ngram entries.
This commit is contained in:
commit
de0d34a1f9
2 changed files with 26 additions and 11 deletions
|
@ -43,18 +43,18 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
|
|||
const int wordId, const HeaderPolicy *const headerPolicy) const {
|
||||
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
|
||||
int maxLevel = 0;
|
||||
int maxPrevWordCount = 0;
|
||||
for (size_t i = 0; i < prevWordIds.size(); ++i) {
|
||||
const int nextBitmapEntryIndex =
|
||||
mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
|
||||
if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
||||
break;
|
||||
}
|
||||
maxLevel = i + 1;
|
||||
maxPrevWordCount = i + 1;
|
||||
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
|
||||
}
|
||||
|
||||
for (int i = maxLevel; i >= 0; --i) {
|
||||
for (int i = maxPrevWordCount; i >= 0; --i) {
|
||||
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
|
||||
if (!result.mIsValid) {
|
||||
continue;
|
||||
|
@ -69,9 +69,24 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
|
|||
// The entry should not be treated as a valid entry.
|
||||
continue;
|
||||
}
|
||||
probability = std::min(rawProbability
|
||||
+ ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
|
||||
MAX_PROBABILITY);
|
||||
if (i == 0) {
|
||||
// unigram
|
||||
probability = rawProbability;
|
||||
} else {
|
||||
const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
|
||||
prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
|
||||
if (!prevWordProbabilityEntry.isValid()) {
|
||||
continue;
|
||||
}
|
||||
if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
|
||||
probability = rawProbability;
|
||||
} else {
|
||||
const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
|
||||
prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
|
||||
probability = std::min(MAX_PROBABILITY - prevWordRawProbability
|
||||
+ rawProbability, MAX_PROBABILITY);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
probability = probabilityEntry.getProbability();
|
||||
}
|
||||
|
|
|
@ -98,17 +98,17 @@ class ProbabilityEntry {
|
|||
}
|
||||
|
||||
uint64_t encode(const bool hasHistoricalInfo) const {
|
||||
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
|
||||
uint64_t encodedEntry = static_cast<uint8_t>(mFlags);
|
||||
if (hasHistoricalInfo) {
|
||||
encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
|
||||
^ static_cast<uint64_t>(mHistoricalInfo.getTimestamp());
|
||||
| static_cast<uint32_t>(mHistoricalInfo.getTimestamp());
|
||||
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
|
||||
^ static_cast<uint64_t>(mHistoricalInfo.getLevel());
|
||||
| static_cast<uint8_t>(mHistoricalInfo.getLevel());
|
||||
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
|
||||
^ static_cast<uint64_t>(mHistoricalInfo.getCount());
|
||||
| static_cast<uint8_t>(mHistoricalInfo.getCount());
|
||||
} else {
|
||||
encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
|
||||
^ static_cast<uint64_t>(mProbability);
|
||||
| static_cast<uint8_t>(mProbability);
|
||||
}
|
||||
return encodedEntry;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue