Merge "Implement LanguageModelDictContent.getWordProbability()."

This commit is contained in:
Keisuke Kuroyanagi 2014-09-10 11:40:39 +00:00 committed by Android (Google) Code Review
commit 38085ee7ae
8 changed files with 95 additions and 40 deletions

View file

@ -354,7 +354,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
} }
bool addedNewBigram = false; bool addedNewBigram = false;
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos), if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, bigramProperty, &addedNewBigram)) { wordPos, bigramProperty, &addedNewBigram)) {
if (addedNewBigram) { if (addedNewBigram) {
mBigramCount++; mBigramCount++;
@ -396,7 +396,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
} }
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry( if (mUpdatingHelper.removeNgramEntry(
PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) { PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
mBigramCount--; mBigramCount--;
return true; return true;
} else { } else {

View file

@ -38,6 +38,40 @@ bool LanguageModelDictContent::runGC(
0 /* nextLevelBitmapEntryIndex */, outNgramCount); 0 /* nextLevelBitmapEntryIndex */, outNgramCount);
} }
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
const int wordId) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
int maxLevel = 0;
for (size_t i = 0; i < prevWordIds.size(); ++i) {
const int nextBitmapEntryIndex =
mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
break;
}
maxLevel = i + 1;
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
}
for (int i = maxLevel; i >= 0; --i) {
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
if (!result.mIsValid) {
continue;
}
const int probability =
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
if (mHasHistoricalInfo) {
return std::min(
probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
MAX_PROBABILITY);
} else {
return probability;
}
}
// Cannot find the word.
return NOT_A_PROBABILITY;
}
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
const WordIdArrayView prevWordIds, const int wordId) const { const WordIdArrayView prevWordIds, const int wordId) const {
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);

View file

@ -128,6 +128,8 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent, const LanguageModelDictContent *const originalContent,
int *const outNgramCount); int *const outNgramCount);
int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
ProbabilityEntry getProbabilityEntry(const int wordId) const { ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId); return getNgramProbabilityEntry(WordIdArrayView(), wordId);
} }

View file

@ -115,24 +115,12 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds, int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
const int wordId, MultiBigramMap *const multiBigramMap) const { const int wordId, MultiBigramMap *const multiBigramMap) const {
// TODO: Quit using MultiBigramMap.
if (wordId == NOT_A_WORD_ID) { if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
const int ptNodePos = // TODO: Support n-gram.
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); return mBuffers->getLanguageModelDictContent()->getWordProbability(
const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); WordIdArrayView::singleElementView(prevWordIds), wordId);
if (multiBigramMap) {
return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
wordId, ptNodeParams.getProbability());
}
if (prevWordIds) {
const int probability = getProbabilityOfWord(prevWordIds, wordId);
if (probability != NOT_A_PROBABILITY) {
return probability;
}
}
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
} }
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
@ -166,7 +154,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
// TODO: Support n-gram. // TODO: Support n-gram.
const ProbabilityEntry probabilityEntry = const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry( mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
IntArrayView::fromObject(prevWordIds), wordId); IntArrayView::singleElementView(prevWordIds), wordId);
if (!probabilityEntry.isValid()) { if (!probabilityEntry.isValid()) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
@ -194,7 +182,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
// TODO: Support n-gram. // TODO: Support n-gram.
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent(); const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
for (const auto entry : languageModelDictContent->getProbabilityEntries( for (const auto entry : languageModelDictContent->getProbabilityEntries(
WordIdArrayView::fromObject(prevWordIds))) { WordIdArrayView::singleElementView(prevWordIds))) {
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
const int probability = probabilityEntry.hasHistoricalInfo() ? const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability( ForgettingCurveUtils::decodeProbability(
@ -511,7 +499,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
// Fetch bigram information. // Fetch bigram information.
// TODO: Support n-gram. // TODO: Support n-gram.
std::vector<BigramProperty> bigrams; std::vector<BigramProperty> bigrams;
const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId); const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
int bigramWord1CodePoints[MAX_WORD_LENGTH]; int bigramWord1CodePoints[MAX_WORD_LENGTH];
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries( for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
prevWordIds)) { prevWordIds)) {

View file

@ -48,6 +48,11 @@ class ForgettingCurveUtils {
static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount, static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
const int bigramCount, const HeaderPolicy *const headerPolicy); const int bigramCount, const HeaderPolicy *const headerPolicy);
// TODO: Improve probability computation method and remove this.
static int getProbabilityBiasForNgram(const int n) {
return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
}
AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) { AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
return static_cast<int>(static_cast<float>(maxUnigramCount) return static_cast<int>(static_cast<float>(maxUnigramCount)
* UNIGRAM_COUNT_HARD_LIMIT_WEIGHT); * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);

View file

@ -61,9 +61,9 @@ class IntArrayView {
return IntArrayView(array, N); return IntArrayView(array, N);
} }
// Returns a view that points one int object. Does not take ownership of the given object. // Returns a view that points one int object.
AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) { AK_FORCE_INLINE static IntArrayView singleElementView(const int *const ptr) {
return IntArrayView(object, 1); return IntArrayView(ptr, 1);
} }
AK_FORCE_INLINE int operator[](const size_t index) const { AK_FORCE_INLINE int operator[](const size_t index) const {

View file

@ -26,28 +26,28 @@ namespace latinime {
namespace { namespace {
TEST(LanguageModelDictContentTest, TestUnigramProbability) { TEST(LanguageModelDictContentTest, TestUnigramProbability) {
LanguageModelDictContent LanguageModelDictContent(false /* useHistoricalInfo */); LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
const int flag = 0xFF; const int flag = 0xFF;
const int probability = 10; const int probability = 10;
const int wordId = 100; const int wordId = 100;
const ProbabilityEntry probabilityEntry(flag, probability); const ProbabilityEntry probabilityEntry(flag, probability);
LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
const ProbabilityEntry entry = const ProbabilityEntry entry =
LanguageModelDictContent.getProbabilityEntry(wordId); languageModelDictContent.getProbabilityEntry(wordId);
EXPECT_EQ(flag, entry.getFlags()); EXPECT_EQ(flag, entry.getFlags());
EXPECT_EQ(probability, entry.getProbability()); EXPECT_EQ(probability, entry.getProbability());
// Remove // Remove
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId)); EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid()); EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId)); EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry)); EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
EXPECT_TRUE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid()); EXPECT_TRUE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
} }
TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) { TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
LanguageModelDictContent LanguageModelDictContent(true /* useHistoricalInfo */); LanguageModelDictContent languageModelDictContent(true /* useHistoricalInfo */);
const int flag = 0xF0; const int flag = 0xF0;
const int timestamp = 0x3FFFFFFF; const int timestamp = 0x3FFFFFFF;
@ -56,19 +56,19 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
const int wordId = 100; const int wordId = 100;
const HistoricalInfo historicalInfo(timestamp, level, count); const HistoricalInfo historicalInfo(timestamp, level, count);
const ProbabilityEntry probabilityEntry(flag, &historicalInfo); const ProbabilityEntry probabilityEntry(flag, &historicalInfo);
LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId); const ProbabilityEntry entry = languageModelDictContent.getProbabilityEntry(wordId);
EXPECT_EQ(flag, entry.getFlags()); EXPECT_EQ(flag, entry.getFlags());
EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp()); EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp());
EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel()); EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel());
EXPECT_EQ(count, entry.getHistoricalInfo()->getCount()); EXPECT_EQ(count, entry.getHistoricalInfo()->getCount());
// Remove // Remove
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId)); EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid()); EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId)); EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry)); EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId)); EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
} }
TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) { TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
@ -89,5 +89,31 @@ TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
EXPECT_TRUE(wordIdSet.empty()); EXPECT_TRUE(wordIdSet.empty());
} }
TEST(LanguageModelDictContentTest, TestGetWordProbability) {
LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
const int flag = 0xFF;
const int probability = 10;
const int bigramProbability = 20;
const int trigramProbability = 30;
const int wordId = 100;
const int prevWordIdArray[] = { 1, 2 };
const WordIdArrayView prevWordIds = WordIdArrayView::fromFixedSizeArray(prevWordIdArray);
const ProbabilityEntry probabilityEntry(flag, probability);
languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
const ProbabilityEntry bigramProbabilityEntry(flag, bigramProbability);
languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
&bigramProbabilityEntry);
EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
prevWordIds[1], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
&trigramProbabilityEntry);
EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
}
} // namespace } // namespace
} // namespace latinime } // namespace latinime

View file

@ -52,7 +52,7 @@ TEST(IntArrayViewTest, TestConstructFromArray) {
TEST(IntArrayViewTest, TestConstructFromObject) { TEST(IntArrayViewTest, TestConstructFromObject) {
const int object = 10; const int object = 10;
const auto intArrayView = IntArrayView::fromObject(&object); const auto intArrayView = IntArrayView::singleElementView(&object);
EXPECT_EQ(1u, intArrayView.size()); EXPECT_EQ(1u, intArrayView.size());
EXPECT_EQ(object, intArrayView[0]); EXPECT_EQ(object, intArrayView[0]);
} }