am 627b0107: Merge "Support decaying dict in getWordProbability()."
* commit '627b0107ea813622bb32b77bc3ab335942fb89f8': Support decaying dict in getWordProbability().
main
commit
2fa6d02a53
|
@ -39,7 +39,7 @@ bool LanguageModelDictContent::runGC(
|
||||||
}
|
}
|
||||||
|
|
||||||
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
|
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
|
||||||
const int wordId) const {
|
const int wordId, const HeaderPolicy *const headerPolicy) const {
|
||||||
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||||
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
|
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
|
||||||
int maxLevel = 0;
|
int maxLevel = 0;
|
||||||
|
@ -58,14 +58,15 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI
|
||||||
if (!result.mIsValid) {
|
if (!result.mIsValid) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const int probability =
|
const ProbabilityEntry probabilityEntry =
|
||||||
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
|
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
|
||||||
if (mHasHistoricalInfo) {
|
if (mHasHistoricalInfo) {
|
||||||
return std::min(
|
const int probability = ForgettingCurveUtils::decodeProbability(
|
||||||
probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
|
probabilityEntry.getHistoricalInfo(), headerPolicy)
|
||||||
MAX_PROBABILITY);
|
+ ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */);
|
||||||
|
return std::min(probability, MAX_PROBABILITY);
|
||||||
} else {
|
} else {
|
||||||
return probability;
|
return probabilityEntry.getProbability();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Cannot find the word.
|
// Cannot find the word.
|
||||||
|
|
|
@ -128,7 +128,8 @@ class LanguageModelDictContent {
|
||||||
const LanguageModelDictContent *const originalContent,
|
const LanguageModelDictContent *const originalContent,
|
||||||
int *const outNgramCount);
|
int *const outNgramCount);
|
||||||
|
|
||||||
int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
|
int getWordProbability(const WordIdArrayView prevWordIds, const int wordId,
|
||||||
|
const HeaderPolicy *const headerPolicy) const;
|
||||||
|
|
||||||
ProbabilityEntry getProbabilityEntry(const int wordId) const {
|
ProbabilityEntry getProbabilityEntry(const int wordId) const {
|
||||||
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
|
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
|
||||||
|
|
|
@ -121,9 +121,10 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
||||||
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
|
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
|
||||||
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
||||||
// TODO: Support n-gram.
|
// TODO: Support n-gram.
|
||||||
return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
|
const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability(
|
||||||
prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(),
|
prevWordIds.limit(1 /* maxSize */), wordId, mHeaderPolicy);
|
||||||
ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
|
return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
|
||||||
|
probability == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
|
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
|
||||||
|
|
|
@ -107,13 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
|
||||||
languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
|
languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
|
||||||
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
|
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
|
||||||
&bigramProbabilityEntry);
|
&bigramProbabilityEntry);
|
||||||
EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
|
EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
|
||||||
|
nullptr /* headerPolicy */));
|
||||||
const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
|
const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
|
||||||
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
|
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
|
||||||
prevWordIds[1], &probabilityEntry);
|
prevWordIds[1], &probabilityEntry);
|
||||||
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
|
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
|
||||||
&trigramProbabilityEntry);
|
&trigramProbabilityEntry);
|
||||||
EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
|
EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
|
||||||
|
nullptr /* headerPolicy */));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
Loading…
Reference in New Issue