Merge "Add bigrams to language model content."
commit 47ae73685a
10 changed files with 96 additions and 22 deletions
@@ -234,8 +234,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
 bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
         const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) {
     if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) {
-        AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
-                sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId());
+        AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
+                prevWordIds[0], wordId);
         return false;
     }
     const int ptNodePos =
@@ -46,7 +46,7 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
 
 bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
         const int terminalId, const ProbabilityEntry *const probabilityEntry) {
-    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
+    const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
     if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
         return false;
     }
@@ -80,6 +80,19 @@ bool LanguageModelDictContent::runGCInner(
     return true;
 }
 
+int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
+    if (prevWordIds.empty()) {
+        return mTrieMap.getRootBitmapEntryIndex();
+    }
+    const int lastBitmapEntryIndex =
+            getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
+    if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+        return TrieMap::INVALID_INDEX;
+    }
+    return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1],
+            lastBitmapEntryIndex);
+}
+
 int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
     int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
     for (const int wordId : prevWordIds) {
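Aside (not part of the change): a minimal standalone sketch of what createAndGetBitmapEntryIndex() adds, assuming std::map as a stand-in for TrieMap and a plain int for ProbabilityEntry. Each previous word id selects the next level of a nested map, and the level is now created on demand, so a bigram entry can be written the first time a context word is seen. Only the lookup shape is mirrored; the real class stores all levels inside one TrieMap addressed by bitmap entry indices.

// Standalone analogue, NOT the real classes: std::map stands in for TrieMap,
// a plain int stands in for ProbabilityEntry.
#include <cstdio>
#include <map>
#include <vector>

static const int NOT_A_PROBABILITY = -1;

class ToyLanguageModelContent {
 public:
    // Analogue of setNgramProbabilityEntry() after this change: the level for
    // the previous word is created on demand (createAndGetBitmapEntryIndex)
    // instead of failing when it does not exist yet (getBitmapEntryIndex).
    void setNgramEntry(const std::vector<int> &prevWordIds, const int wordId,
            const int probability) {
        if (prevWordIds.empty()) {
            mUnigramLevel[wordId] = probability;
            return;
        }
        // operator[] creates the missing level, like createAndGetBitmapEntryIndex().
        mBigramLevels[prevWordIds.back()][wordId] = probability;
    }

    // Analogue of getNgramProbabilityEntry(): a read-only walk, no creation.
    int getNgramEntry(const std::vector<int> &prevWordIds, const int wordId) const {
        const std::map<int, int> *level = &mUnigramLevel;
        if (!prevWordIds.empty()) {
            const auto levelIt = mBigramLevels.find(prevWordIds.back());
            if (levelIt == mBigramLevels.end()) {
                return NOT_A_PROBABILITY;
            }
            level = &levelIt->second;
        }
        const auto entryIt = level->find(wordId);
        return entryIt == level->end() ? NOT_A_PROBABILITY : entryIt->second;
    }

 private:
    std::map<int, int> mUnigramLevel;                 // wordId -> probability
    std::map<int, std::map<int, int>> mBigramLevels;  // prevWordId -> bigram level
};

int main() {
    ToyLanguageModelContent content;
    content.setNgramEntry({3}, 7, 42);                    // store P(word 7 | word 3)
    std::printf("%d\n", content.getNgramEntry({3}, 7));   // 42
    std::printf("%d\n", content.getNgramEntry({5}, 7));   // -1, level never created
    return 0;
}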
@@ -76,7 +76,7 @@ class LanguageModelDictContent {
     bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
             const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
             int *const outNgramCount);
 
+    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
     int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
 };
 } // namespace latinime
@@ -21,6 +21,8 @@
 #include <cstdint>
 
 #include "defines.h"
+#include "suggest/core/dictionary/property/bigram_property.h"
+#include "suggest/core/dictionary/property/unigram_property.h"
 #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
 #include "suggest/policyimpl/dictionary/utils/historical_info.h"
 
@@ -45,6 +47,20 @@ class ProbabilityEntry {
             const HistoricalInfo *const historicalInfo)
             : mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}
 
+    // Create from unigram property.
+    // TODO: Set flags.
+    ProbabilityEntry(const UnigramProperty *const unigramProperty)
+            : mFlags(0), mProbability(unigramProperty->getProbability()),
+              mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
+                      unigramProperty->getCount()) {}
+
+    // Create from bigram property.
+    // TODO: Set flags.
+    ProbabilityEntry(const BigramProperty *const bigramProperty)
+            : mFlags(0), mProbability(bigramProperty->getProbability()),
+              mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
+                      bigramProperty->getCount()) {}
+
     const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
         return ProbabilityEntry(mFlags, probability, &mHistoricalInfo);
     }
@@ -54,6 +70,10 @@ class ProbabilityEntry {
         return ProbabilityEntry(mFlags, mProbability, historicalInfo);
     }
 
+    bool isValid() const {
+        return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
+    }
+
     bool hasHistoricalInfo() const {
         return mHistoricalInfo.isValid();
     }
@@ -89,7 +109,7 @@ class ProbabilityEntry {
     static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
         if (hasHistoricalInfo) {
             const int flags = readFromEncodedEntry(encodedEntry,
-                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
+                    Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
                     Ver4DictConstants::TIME_STAMP_FIELD_SIZE
                             + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
                             + Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
@@ -106,7 +126,7 @@ class ProbabilityEntry {
             return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
         } else {
             const int flags = readFromEncodedEntry(encodedEntry,
-                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
+                    Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
                     Ver4DictConstants::PROBABILITY_SIZE);
             const int probability = readFromEncodedEntry(encodedEntry,
                     Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);
@@ -46,7 +46,7 @@ const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX =
 
 const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1;
 const int Ver4DictConstants::PROBABILITY_SIZE = 1;
-const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1;
+const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1;
 const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3;
 const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0;
 const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4;
@@ -41,7 +41,7 @@ class Ver4DictConstants {
 
     static const int NOT_A_TERMINAL_ID;
     static const int PROBABILITY_SIZE;
-    static const int FLAGS_IN_PROBABILITY_FILE_SIZE;
+    static const int FLAGS_IN_LANGUAGE_MODEL_SIZE;
     static const int TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE;
     static const int NOT_A_TERMINAL_ADDRESS;
     static const int TERMINAL_ID_FIELD_SIZE;
@@ -145,10 +145,11 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
     const ProbabilityEntry originalProbabilityEntry =
             mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
                     toBeUpdatedPtNodeParams->getTerminalId());
-    const ProbabilityEntry probabilityEntry = createUpdatedEntryFrom(&originalProbabilityEntry,
-            unigramProperty);
+    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
+    const ProbabilityEntry updatedProbabilityEntry =
+            createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
     return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
-            toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry);
+            toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
 }
 
 bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
@@ -216,16 +217,36 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
     }
     // Write probability.
     ProbabilityEntry newProbabilityEntry;
+    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
     const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
-            &newProbabilityEntry, unigramProperty);
+            &newProbabilityEntry, &probabilityEntryOfUnigramProperty);
     return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
             terminalId, &probabilityEntryToWrite);
 }
 
 bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
         const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) {
+    // TODO: Support n-gram.
+    LanguageModelDictContent *const languageModelDictContent =
+            mBuffers->getMutableLanguageModelDictContent();
+    const ProbabilityEntry probabilityEntry =
+            languageModelDictContent->getNgramProbabilityEntry(
+                    prevWordIds.limit(1 /* maxSize */), wordId);
+    const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty);
+    const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
+            &probabilityEntry, &probabilityEntryOfBigramProperty);
+    if (!languageModelDictContent->setNgramProbabilityEntry(
+            prevWordIds.limit(1 /* maxSize */), wordId, &updatedProbabilityEntry)) {
+        AKLOGE("Cannot add new ngram entry. prevWordId: %d, wordId: %d",
+                prevWordIds[0], wordId);
+        return false;
+    }
+    if (!probabilityEntry.isValid() && outAddedNewBigram) {
+        *outAddedNewBigram = true;
+    }
+    // TODO: Remove.
     if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) {
-        AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
+        AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
                 prevWordIds[0], wordId);
         return false;
     }
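Aside (not part of the change): the rewritten addNgramEntry() above follows a read-merge-write pattern: read the current entry for (prevWord, word), build an entry from the incoming BigramProperty, merge the two with createUpdatedEntryFrom(), write the result back, and report a new entry when nothing valid existed before. A rough standalone sketch of that pattern with simplified, invented types (ToyEntry, ToyNgramStore); the fallback call into mBigramPolicy that the real code still makes (the "TODO: Remove" path) is omitted here.

#include <cstdio>
#include <map>
#include <utility>

// Stand-ins for ProbabilityEntry / BigramProperty; only the fields needed to
// show the update flow are modeled.
struct ToyEntry {
    int probability = -1;  // -1 plays the role of NOT_A_PROBABILITY
    bool isValid() const { return probability != -1; }
};

struct ToyBigramProperty {
    int probability;
};

class ToyNgramStore {
 public:
    // Mirrors the flow of Ver4PatriciaTrieNodeWriter::addNgramEntry() after
    // this change: read the current entry, merge, write back, report new entries.
    bool addNgramEntry(const int prevWordId, const int wordId,
            const ToyBigramProperty &property, bool *const outAddedNewEntry) {
        const ToyEntry current = getEntry(prevWordId, wordId);
        const ToyEntry updated = createUpdatedEntryFrom(current, property);
        mEntries[{prevWordId, wordId}] = updated;
        if (!current.isValid() && outAddedNewEntry) {
            *outAddedNewEntry = true;  // nothing valid existed for this pair before
        }
        return true;
    }

 private:
    ToyEntry getEntry(const int prevWordId, const int wordId) const {
        const auto it = mEntries.find({prevWordId, wordId});
        return it == mEntries.end() ? ToyEntry() : it->second;
    }

    // The real createUpdatedEntryFrom() also merges historical info via
    // ForgettingCurveUtils; here we simply take the new probability.
    static ToyEntry createUpdatedEntryFrom(const ToyEntry & /* original */,
            const ToyBigramProperty &property) {
        ToyEntry entry;
        entry.probability = property.probability;
        return entry;
    }

    std::map<std::pair<int, int>, ToyEntry> mEntries;
};

int main() {
    ToyNgramStore store;
    bool addedNewEntry = false;
    store.addNgramEntry(3, 7, ToyBigramProperty{120}, &addedNewEntry);
    std::printf("added new: %d\n", addedNewEntry);  // 1 on the first insertion
    addedNewEntry = false;
    store.addNgramEntry(3, 7, ToyBigramProperty{130}, &addedNewEntry);
    std::printf("added new: %d\n", addedNewEntry);  // 0, the entry already existed
    return 0;
}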
@@ -234,6 +255,7 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
 
 bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds,
         const int wordId) {
+    // TODO: Remove.
     return mBigramPolicy->removeEntry(prevWordIds[0], wordId);
 }
 
@@ -352,20 +374,19 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
 
 const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
         const ProbabilityEntry *const originalProbabilityEntry,
-        const UnigramProperty *const unigramProperty) const {
+        const ProbabilityEntry *const probabilityEntry) const {
     // TODO: Consolidate historical info and probability.
     if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
-        const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(),
-                unigramProperty->getLevel(), unigramProperty->getCount());
         const HistoricalInfo updatedHistoricalInfo =
                 ForgettingCurveUtils::createUpdatedHistoricalInfo(
                         originalProbabilityEntry->getHistoricalInfo(),
-                        unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy);
+                        probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
+                        mHeaderPolicy);
         return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo(
                 &updatedHistoricalInfo);
     } else {
         return originalProbabilityEntry->createEntryWithUpdatedProbability(
-                unigramProperty->getProbability());
+                probabilityEntry->getProbability());
     }
 }
 
@@ -98,12 +98,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
             const PtNodeParams *const ptNodeParams, int *const outTerminalId,
             int *const ptNodeWritingPos);
 
-    // Create updated probability entry using given unigram property. In addition to the
+    // Create updated probability entry using given probability property. In addition to the
     // probability, this method updates historical information if needed.
-    // TODO: Update flags belonging to the unigram property.
+    // TODO: Update flags.
     const ProbabilityEntry createUpdatedEntryFrom(
             const ProbabilityEntry *const originalProbabilityEntry,
-            const UnigramProperty *const unigramProperty) const;
+            const ProbabilityEntry *const probabilityEntry) const;
 
     bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord,
             const bool isTerminal, const bool hasMultipleChars);
@@ -91,6 +91,11 @@ class IntArrayView {
         return mPtr + mSize;
     }
 
+    // Returns the view whose size is smaller than or equal to the given count.
+    const IntArrayView limit(const size_t maxSize) const {
+        return IntArrayView(mPtr, std::min(maxSize, mSize));
+    }
+
 private:
     DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView);
 
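Aside (not part of the change): limit() is what lets callers shrink a word-id view to just the context they need, e.g. prevWordIds.limit(1 /* maxSize */) in addNgramEntry() and prevWordIds.limit(prevWordIds.size() - 1) in createAndGetBitmapEntryIndex() above, and it is safe to pass a maxSize larger than the view. A tiny standalone sketch of the same clamping behaviour with a simplified view type (not the real IntArrayView):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Simplified read-only view over an int array; only the pieces needed to
// demonstrate limit() are modeled.
class ToyIntArrayView {
 public:
    ToyIntArrayView(const int *const ptr, const std::size_t size)
            : mPtr(ptr), mSize(size) {}
    explicit ToyIntArrayView(const std::vector<int> &vector)
            : mPtr(vector.data()), mSize(vector.size()) {}

    std::size_t size() const { return mSize; }
    bool empty() const { return mSize == 0; }
    int operator[](const std::size_t index) const { return mPtr[index]; }

    // Same clamping idea as the new IntArrayView::limit(): the result never
    // grows past the original view, so limit(1000) on 6 elements is still 6.
    ToyIntArrayView limit(const std::size_t maxSize) const {
        return ToyIntArrayView(mPtr, std::min(maxSize, mSize));
    }

 private:
    const int *mPtr;
    std::size_t mSize;
};

int main() {
    const std::vector<int> prevWordIds = {3, 2, 1, 0, -1, -2};
    const ToyIntArrayView view(prevWordIds);
    std::printf("%zu %zu %zu\n",
            view.limit(0).size(),      // 0
            view.limit(1).size(),      // 1, e.g. just the nearest previous word
            view.limit(1000).size());  // 6, clamped to the view size
    return 0;
}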
@@ -53,9 +53,24 @@ TEST(IntArrayViewTest, TestConstructFromArray) {
 TEST(IntArrayViewTest, TestConstructFromObject) {
     const int object = 10;
     const auto intArrayView = IntArrayView::fromObject(&object);
-    EXPECT_EQ(1, intArrayView.size());
+    EXPECT_EQ(1u, intArrayView.size());
     EXPECT_EQ(object, intArrayView[0]);
 }
 
+TEST(IntArrayViewTest, TestLimit) {
+    const std::vector<int> intVector = {3, 2, 1, 0, -1, -2};
+    IntArrayView intArrayView(intVector);
+
+    EXPECT_TRUE(intArrayView.limit(0).empty());
+    EXPECT_EQ(intArrayView.size(), intArrayView.limit(intArrayView.size()).size());
+    EXPECT_EQ(intArrayView.size(), intArrayView.limit(1000).size());
+
+    IntArrayView subView = intArrayView.limit(4);
+    EXPECT_EQ(4u, subView.size());
+    for (size_t i = 0; i < subView.size(); ++i) {
+        EXPECT_EQ(intVector[i], subView[i]);
+    }
+}
+
 } // namespace
 } // namespace latinime