am 627b0107: Merge "Support decaying dict in getWordProbability()."

* commit '627b0107ea813622bb32b77bc3ab335942fb89f8':
  Support decaying dict in getWordProbability().
main
Keisuke Kuroyanagi 2014-09-16 12:04:55 +00:00 committed by Android Git Automerger
commit 2fa6d02a53
4 changed files with 18 additions and 13 deletions

View File

@ -39,7 +39,7 @@ bool LanguageModelDictContent::runGC(
} }
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds, int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
const int wordId) const { const int wordId, const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
int maxLevel = 0; int maxLevel = 0;
@ -58,14 +58,15 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI
if (!result.mIsValid) { if (!result.mIsValid) {
continue; continue;
} }
const int probability = const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability(); ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
if (mHasHistoricalInfo) { if (mHasHistoricalInfo) {
return std::min( const int probability = ForgettingCurveUtils::decodeProbability(
probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */), probabilityEntry.getHistoricalInfo(), headerPolicy)
MAX_PROBABILITY); + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */);
return std::min(probability, MAX_PROBABILITY);
} else { } else {
return probability; return probabilityEntry.getProbability();
} }
} }
// Cannot find the word. // Cannot find the word.

View File

@ -128,7 +128,8 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent, const LanguageModelDictContent *const originalContent,
int *const outNgramCount); int *const outNgramCount);
int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const; int getWordProbability(const WordIdArrayView prevWordIds, const int wordId,
const HeaderPolicy *const headerPolicy) const;
ProbabilityEntry getProbabilityEntry(const int wordId) const { ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId); return getNgramProbabilityEntry(WordIdArrayView(), wordId);

View File

@ -121,9 +121,10 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
// TODO: Support n-gram. // TODO: Support n-gram.
return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability( const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability(
prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(), prevWordIds.limit(1 /* maxSize */), wordId, mHeaderPolicy);
ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0); return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
probability == 0);
} }
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,

View File

@ -107,13 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry); languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId, languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
&bigramProbabilityEntry); &bigramProbabilityEntry);
EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId)); EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
nullptr /* headerPolicy */));
const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability); const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
prevWordIds[1], &probabilityEntry); prevWordIds[1], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId, languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
&trigramProbabilityEntry); &trigramProbabilityEntry);
EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId)); EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
nullptr /* headerPolicy */));
} }
} // namespace } // namespace