Merge "Support decaying dict in getWordProbability()."
This commit is contained in:
commit
627b0107ea
4 changed files with 18 additions and 13 deletions
|
@@ -39,7 +39,7 @@ bool LanguageModelDictContent::runGC(
|
|||
}
|
||||
|
||||
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
|
||||
const int wordId) const {
|
||||
const int wordId, const HeaderPolicy *const headerPolicy) const {
|
||||
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
|
||||
int maxLevel = 0;
|
||||
|
@@ -58,14 +58,15 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI
|
|||
if (!result.mIsValid) {
|
||||
continue;
|
||||
}
|
||||
const int probability =
|
||||
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
|
||||
const ProbabilityEntry probabilityEntry =
|
||||
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
|
||||
if (mHasHistoricalInfo) {
|
||||
return std::min(
|
||||
probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
|
||||
MAX_PROBABILITY);
|
||||
const int probability = ForgettingCurveUtils::decodeProbability(
|
||||
probabilityEntry.getHistoricalInfo(), headerPolicy)
|
||||
+ ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */);
|
||||
return std::min(probability, MAX_PROBABILITY);
|
||||
} else {
|
||||
return probability;
|
||||
return probabilityEntry.getProbability();
|
||||
}
|
||||
}
|
||||
// Cannot find the word.
|
||||
|
|
|
@@ -128,7 +128,8 @@ class LanguageModelDictContent {
|
|||
const LanguageModelDictContent *const originalContent,
|
||||
int *const outNgramCount);
|
||||
|
||||
int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
|
||||
int getWordProbability(const WordIdArrayView prevWordIds, const int wordId,
|
||||
const HeaderPolicy *const headerPolicy) const;
|
||||
|
||||
ProbabilityEntry getProbabilityEntry(const int wordId) const {
|
||||
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
|
||||
|
|
|
@@ -121,9 +121,10 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
|||
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
|
||||
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
||||
// TODO: Support n-gram.
|
||||
return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
|
||||
prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(),
|
||||
ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
|
||||
const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability(
|
||||
prevWordIds.limit(1 /* maxSize */), wordId, mHeaderPolicy);
|
||||
return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
|
||||
probability == 0);
|
||||
}
|
||||
|
||||
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
|
||||
|
|
|
@@ -107,13 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
|
|||
languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
|
||||
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
|
||||
&bigramProbabilityEntry);
|
||||
EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
|
||||
EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
|
||||
nullptr /* headerPolicy */));
|
||||
const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
|
||||
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
|
||||
prevWordIds[1], &probabilityEntry);
|
||||
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
|
||||
&trigramProbabilityEntry);
|
||||
EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
|
||||
EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
|
||||
nullptr /* headerPolicy */));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
Loading…
Reference in a new issue