Merge "Implement LanguageModelDictContent.getWordProbability()."
This commit is contained in:
commit
38085ee7ae
8 changed files with 95 additions and 40 deletions
|
@ -354,7 +354,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
|
||||||
}
|
}
|
||||||
bool addedNewBigram = false;
|
bool addedNewBigram = false;
|
||||||
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
|
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
|
||||||
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos),
|
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
|
||||||
wordPos, bigramProperty, &addedNewBigram)) {
|
wordPos, bigramProperty, &addedNewBigram)) {
|
||||||
if (addedNewBigram) {
|
if (addedNewBigram) {
|
||||||
mBigramCount++;
|
mBigramCount++;
|
||||||
|
@ -396,7 +396,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
|
||||||
}
|
}
|
||||||
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
|
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
|
||||||
if (mUpdatingHelper.removeNgramEntry(
|
if (mUpdatingHelper.removeNgramEntry(
|
||||||
PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) {
|
PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
|
||||||
mBigramCount--;
|
mBigramCount--;
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -38,6 +38,40 @@ bool LanguageModelDictContent::runGC(
|
||||||
0 /* nextLevelBitmapEntryIndex */, outNgramCount);
|
0 /* nextLevelBitmapEntryIndex */, outNgramCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
|
||||||
|
const int wordId) const {
|
||||||
|
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
||||||
|
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
|
||||||
|
int maxLevel = 0;
|
||||||
|
for (size_t i = 0; i < prevWordIds.size(); ++i) {
|
||||||
|
const int nextBitmapEntryIndex =
|
||||||
|
mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
|
||||||
|
if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
maxLevel = i + 1;
|
||||||
|
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = maxLevel; i >= 0; --i) {
|
||||||
|
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
|
||||||
|
if (!result.mIsValid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const int probability =
|
||||||
|
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
|
||||||
|
if (mHasHistoricalInfo) {
|
||||||
|
return std::min(
|
||||||
|
probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
|
||||||
|
MAX_PROBABILITY);
|
||||||
|
} else {
|
||||||
|
return probability;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Cannot find the word.
|
||||||
|
return NOT_A_PROBABILITY;
|
||||||
|
}
|
||||||
|
|
||||||
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
|
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
|
||||||
const WordIdArrayView prevWordIds, const int wordId) const {
|
const WordIdArrayView prevWordIds, const int wordId) const {
|
||||||
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
|
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
|
||||||
|
|
|
@ -128,6 +128,8 @@ class LanguageModelDictContent {
|
||||||
const LanguageModelDictContent *const originalContent,
|
const LanguageModelDictContent *const originalContent,
|
||||||
int *const outNgramCount);
|
int *const outNgramCount);
|
||||||
|
|
||||||
|
int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
|
||||||
|
|
||||||
ProbabilityEntry getProbabilityEntry(const int wordId) const {
|
ProbabilityEntry getProbabilityEntry(const int wordId) const {
|
||||||
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
|
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
|
||||||
}
|
}
|
||||||
|
|
|
@ -115,24 +115,12 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
|
int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
|
||||||
const int wordId, MultiBigramMap *const multiBigramMap) const {
|
const int wordId, MultiBigramMap *const multiBigramMap) const {
|
||||||
// TODO: Quit using MultiBigramMap.
|
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return NOT_A_PROBABILITY;
|
return NOT_A_PROBABILITY;
|
||||||
}
|
}
|
||||||
const int ptNodePos =
|
// TODO: Support n-gram.
|
||||||
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
|
return mBuffers->getLanguageModelDictContent()->getWordProbability(
|
||||||
const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
|
WordIdArrayView::singleElementView(prevWordIds), wordId);
|
||||||
if (multiBigramMap) {
|
|
||||||
return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
|
|
||||||
wordId, ptNodeParams.getProbability());
|
|
||||||
}
|
|
||||||
if (prevWordIds) {
|
|
||||||
const int probability = getProbabilityOfWord(prevWordIds, wordId);
|
|
||||||
if (probability != NOT_A_PROBABILITY) {
|
|
||||||
return probability;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
||||||
|
@ -166,7 +154,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
|
||||||
// TODO: Support n-gram.
|
// TODO: Support n-gram.
|
||||||
const ProbabilityEntry probabilityEntry =
|
const ProbabilityEntry probabilityEntry =
|
||||||
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
|
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
|
||||||
IntArrayView::fromObject(prevWordIds), wordId);
|
IntArrayView::singleElementView(prevWordIds), wordId);
|
||||||
if (!probabilityEntry.isValid()) {
|
if (!probabilityEntry.isValid()) {
|
||||||
return NOT_A_PROBABILITY;
|
return NOT_A_PROBABILITY;
|
||||||
}
|
}
|
||||||
|
@ -194,7 +182,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
|
||||||
// TODO: Support n-gram.
|
// TODO: Support n-gram.
|
||||||
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
|
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
|
||||||
for (const auto entry : languageModelDictContent->getProbabilityEntries(
|
for (const auto entry : languageModelDictContent->getProbabilityEntries(
|
||||||
WordIdArrayView::fromObject(prevWordIds))) {
|
WordIdArrayView::singleElementView(prevWordIds))) {
|
||||||
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
|
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
|
||||||
const int probability = probabilityEntry.hasHistoricalInfo() ?
|
const int probability = probabilityEntry.hasHistoricalInfo() ?
|
||||||
ForgettingCurveUtils::decodeProbability(
|
ForgettingCurveUtils::decodeProbability(
|
||||||
|
@ -511,7 +499,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
|
||||||
// Fetch bigram information.
|
// Fetch bigram information.
|
||||||
// TODO: Support n-gram.
|
// TODO: Support n-gram.
|
||||||
std::vector<BigramProperty> bigrams;
|
std::vector<BigramProperty> bigrams;
|
||||||
const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId);
|
const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
|
||||||
int bigramWord1CodePoints[MAX_WORD_LENGTH];
|
int bigramWord1CodePoints[MAX_WORD_LENGTH];
|
||||||
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
|
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
|
||||||
prevWordIds)) {
|
prevWordIds)) {
|
||||||
|
|
|
@ -48,6 +48,11 @@ class ForgettingCurveUtils {
|
||||||
static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
|
static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
|
||||||
const int bigramCount, const HeaderPolicy *const headerPolicy);
|
const int bigramCount, const HeaderPolicy *const headerPolicy);
|
||||||
|
|
||||||
|
// TODO: Improve probability computation method and remove this.
|
||||||
|
static int getProbabilityBiasForNgram(const int n) {
|
||||||
|
return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
|
||||||
|
}
|
||||||
|
|
||||||
AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
|
AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
|
||||||
return static_cast<int>(static_cast<float>(maxUnigramCount)
|
return static_cast<int>(static_cast<float>(maxUnigramCount)
|
||||||
* UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
|
* UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
|
||||||
|
|
|
@ -61,9 +61,9 @@ class IntArrayView {
|
||||||
return IntArrayView(array, N);
|
return IntArrayView(array, N);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a view that points one int object. Does not take ownership of the given object.
|
// Returns a view that points to one int object.
|
||||||
AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) {
|
AK_FORCE_INLINE static IntArrayView singleElementView(const int *const ptr) {
|
||||||
return IntArrayView(object, 1);
|
return IntArrayView(ptr, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
AK_FORCE_INLINE int operator[](const size_t index) const {
|
AK_FORCE_INLINE int operator[](const size_t index) const {
|
||||||
|
|
|
@ -26,28 +26,28 @@ namespace latinime {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TEST(LanguageModelDictContentTest, TestUnigramProbability) {
|
TEST(LanguageModelDictContentTest, TestUnigramProbability) {
|
||||||
LanguageModelDictContent LanguageModelDictContent(false /* useHistoricalInfo */);
|
LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
|
||||||
|
|
||||||
const int flag = 0xFF;
|
const int flag = 0xFF;
|
||||||
const int probability = 10;
|
const int probability = 10;
|
||||||
const int wordId = 100;
|
const int wordId = 100;
|
||||||
const ProbabilityEntry probabilityEntry(flag, probability);
|
const ProbabilityEntry probabilityEntry(flag, probability);
|
||||||
LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
||||||
const ProbabilityEntry entry =
|
const ProbabilityEntry entry =
|
||||||
LanguageModelDictContent.getProbabilityEntry(wordId);
|
languageModelDictContent.getProbabilityEntry(wordId);
|
||||||
EXPECT_EQ(flag, entry.getFlags());
|
EXPECT_EQ(flag, entry.getFlags());
|
||||||
EXPECT_EQ(probability, entry.getProbability());
|
EXPECT_EQ(probability, entry.getProbability());
|
||||||
|
|
||||||
// Remove
|
// Remove
|
||||||
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
|
EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
|
||||||
EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
|
EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
|
||||||
EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId));
|
EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
|
||||||
EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
|
EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
|
||||||
EXPECT_TRUE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
|
EXPECT_TRUE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
|
TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
|
||||||
LanguageModelDictContent LanguageModelDictContent(true /* useHistoricalInfo */);
|
LanguageModelDictContent languageModelDictContent(true /* useHistoricalInfo */);
|
||||||
|
|
||||||
const int flag = 0xF0;
|
const int flag = 0xF0;
|
||||||
const int timestamp = 0x3FFFFFFF;
|
const int timestamp = 0x3FFFFFFF;
|
||||||
|
@ -56,19 +56,19 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
|
||||||
const int wordId = 100;
|
const int wordId = 100;
|
||||||
const HistoricalInfo historicalInfo(timestamp, level, count);
|
const HistoricalInfo historicalInfo(timestamp, level, count);
|
||||||
const ProbabilityEntry probabilityEntry(flag, &historicalInfo);
|
const ProbabilityEntry probabilityEntry(flag, &historicalInfo);
|
||||||
LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
||||||
const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId);
|
const ProbabilityEntry entry = languageModelDictContent.getProbabilityEntry(wordId);
|
||||||
EXPECT_EQ(flag, entry.getFlags());
|
EXPECT_EQ(flag, entry.getFlags());
|
||||||
EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp());
|
EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp());
|
||||||
EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel());
|
EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel());
|
||||||
EXPECT_EQ(count, entry.getHistoricalInfo()->getCount());
|
EXPECT_EQ(count, entry.getHistoricalInfo()->getCount());
|
||||||
|
|
||||||
// Remove
|
// Remove
|
||||||
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
|
EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
|
||||||
EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
|
EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
|
||||||
EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId));
|
EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
|
||||||
EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
|
EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
|
||||||
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
|
EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
|
TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
|
||||||
|
@ -89,5 +89,31 @@ TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
|
||||||
EXPECT_TRUE(wordIdSet.empty());
|
EXPECT_TRUE(wordIdSet.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that getWordProbability() returns the probability stored for the longest registered
// context: the bigram entry first, then the trigram entry once it has been added.
TEST(LanguageModelDictContentTest, TestGetWordProbability) {
    LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);

    const int flag = 0xFF;
    const int unigramProbability = 10;
    const int bigramProbability = 20;
    const int trigramProbability = 30;
    const int targetWordId = 100;
    const int contextWordIdArray[] = { 1, 2 };
    const WordIdArrayView contextWordIds =
            WordIdArrayView::fromFixedSizeArray(contextWordIdArray);
    const auto bigramContext = contextWordIds.limit(1);
    const auto trigramContext = contextWordIds.limit(2);

    const ProbabilityEntry unigramEntry(flag, unigramProbability);
    const ProbabilityEntry bigramEntry(flag, bigramProbability);
    const ProbabilityEntry trigramEntry(flag, trigramProbability);

    // Register the target word and the first context word as unigrams, then the bigram.
    languageModelDictContent.setProbabilityEntry(targetWordId, &unigramEntry);
    languageModelDictContent.setProbabilityEntry(contextWordIds[0], &unigramEntry);
    languageModelDictContent.setNgramProbabilityEntry(bigramContext, targetWordId, &bigramEntry);
    EXPECT_EQ(bigramProbability,
            languageModelDictContent.getWordProbability(contextWordIds, targetWordId));

    // Extend the context by the second word and register the trigram.
    languageModelDictContent.setNgramProbabilityEntry(bigramContext, contextWordIds[1],
            &unigramEntry);
    languageModelDictContent.setNgramProbabilityEntry(trigramContext, targetWordId,
            &trigramEntry);
    EXPECT_EQ(trigramProbability,
            languageModelDictContent.getWordProbability(contextWordIds, targetWordId));
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
|
|
|
@ -52,7 +52,7 @@ TEST(IntArrayViewTest, TestConstructFromArray) {
|
||||||
|
|
||||||
TEST(IntArrayViewTest, TestConstructFromObject) {
|
TEST(IntArrayViewTest, TestConstructFromObject) {
|
||||||
const int object = 10;
|
const int object = 10;
|
||||||
const auto intArrayView = IntArrayView::fromObject(&object);
|
const auto intArrayView = IntArrayView::singleElementView(&object);
|
||||||
EXPECT_EQ(1u, intArrayView.size());
|
EXPECT_EQ(1u, intArrayView.size());
|
||||||
EXPECT_EQ(object, intArrayView[0]);
|
EXPECT_EQ(object, intArrayView[0]);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue