am 47ae7368: Merge "Add bigrams to language model content."

* commit '47ae73685aab3e758978e1fdc53c3b6e2d5e3a43':
  Add bigrams to language model content.
main
Keisuke Kuroyanagi 2014-08-12 12:53:07 +00:00 committed by Android Git Automerger
commit a5e6c10824
10 changed files with 96 additions and 22 deletions

View File

@ -234,8 +234,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) { const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) {
if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) { if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) {
AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId()); prevWordIds[0], wordId);
return false; return false;
} }
const int ptNodePos = const int ptNodePos =

View File

@ -46,7 +46,7 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds, bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
const int terminalId, const ProbabilityEntry *const probabilityEntry) { const int terminalId, const ProbabilityEntry *const probabilityEntry) {
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
return false; return false;
} }
@ -80,6 +80,19 @@ bool LanguageModelDictContent::runGCInner(
return true; return true;
} }
int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
if (prevWordIds.empty()) {
return mTrieMap.getRootBitmapEntryIndex();
}
const int lastBitmapEntryIndex =
getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
return TrieMap::INVALID_INDEX;
}
return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1],
lastBitmapEntryIndex);
}
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const { int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
for (const int wordId : prevWordIds) { for (const int wordId : prevWordIds) {

View File

@ -76,7 +76,7 @@ class LanguageModelDictContent {
bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex, const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
int *const outNgramCount); int *const outNgramCount);
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
}; };
} // namespace latinime } // namespace latinime

View File

@ -21,6 +21,8 @@
#include <cstdint> #include <cstdint>
#include "defines.h" #include "defines.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/historical_info.h" #include "suggest/policyimpl/dictionary/utils/historical_info.h"
@ -45,6 +47,20 @@ class ProbabilityEntry {
const HistoricalInfo *const historicalInfo) const HistoricalInfo *const historicalInfo)
: mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {} : mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}
// Create from unigram property.
// TODO: Set flags.
ProbabilityEntry(const UnigramProperty *const unigramProperty)
: mFlags(0), mProbability(unigramProperty->getProbability()),
mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
unigramProperty->getCount()) {}
// Create from bigram property.
// TODO: Set flags.
ProbabilityEntry(const BigramProperty *const bigramProperty)
: mFlags(0), mProbability(bigramProperty->getProbability()),
mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
bigramProperty->getCount()) {}
const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const { const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
return ProbabilityEntry(mFlags, probability, &mHistoricalInfo); return ProbabilityEntry(mFlags, probability, &mHistoricalInfo);
} }
@ -54,6 +70,10 @@ class ProbabilityEntry {
return ProbabilityEntry(mFlags, mProbability, historicalInfo); return ProbabilityEntry(mFlags, mProbability, historicalInfo);
} }
// An entry is considered valid when it carries a real probability (anything other than
// NOT_A_PROBABILITY) or, failing that, when it has historical info.
bool isValid() const {
    if (mProbability != NOT_A_PROBABILITY) {
        return true;
    }
    return hasHistoricalInfo();
}
bool hasHistoricalInfo() const { bool hasHistoricalInfo() const {
return mHistoricalInfo.isValid(); return mHistoricalInfo.isValid();
} }
@ -89,7 +109,7 @@ class ProbabilityEntry {
static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) { static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
if (hasHistoricalInfo) { if (hasHistoricalInfo) {
const int flags = readFromEncodedEntry(encodedEntry, const int flags = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
Ver4DictConstants::TIME_STAMP_FIELD_SIZE Ver4DictConstants::TIME_STAMP_FIELD_SIZE
+ Ver4DictConstants::WORD_LEVEL_FIELD_SIZE + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
+ Ver4DictConstants::WORD_COUNT_FIELD_SIZE); + Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
@ -106,7 +126,7 @@ class ProbabilityEntry {
return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo); return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
} else { } else {
const int flags = readFromEncodedEntry(encodedEntry, const int flags = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
Ver4DictConstants::PROBABILITY_SIZE); Ver4DictConstants::PROBABILITY_SIZE);
const int probability = readFromEncodedEntry(encodedEntry, const int probability = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */); Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);

View File

@ -46,7 +46,7 @@ const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX =
const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1; const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1;
const int Ver4DictConstants::PROBABILITY_SIZE = 1; const int Ver4DictConstants::PROBABILITY_SIZE = 1;
const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1; const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1;
const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3; const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3;
const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0; const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0;
const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4; const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4;

View File

@ -41,7 +41,7 @@ class Ver4DictConstants {
static const int NOT_A_TERMINAL_ID; static const int NOT_A_TERMINAL_ID;
static const int PROBABILITY_SIZE; static const int PROBABILITY_SIZE;
static const int FLAGS_IN_PROBABILITY_FILE_SIZE; static const int FLAGS_IN_LANGUAGE_MODEL_SIZE;
static const int TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE; static const int TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE;
static const int NOT_A_TERMINAL_ADDRESS; static const int NOT_A_TERMINAL_ADDRESS;
static const int TERMINAL_ID_FIELD_SIZE; static const int TERMINAL_ID_FIELD_SIZE;

View File

@ -145,10 +145,11 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
const ProbabilityEntry originalProbabilityEntry = const ProbabilityEntry originalProbabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry( mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId()); toBeUpdatedPtNodeParams->getTerminalId());
const ProbabilityEntry probabilityEntry = createUpdatedEntryFrom(&originalProbabilityEntry, const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
unigramProperty); const ProbabilityEntry updatedProbabilityEntry =
createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry); toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
} }
bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC( bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
@ -216,16 +217,36 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
} }
// Write probability. // Write probability.
ProbabilityEntry newProbabilityEntry; ProbabilityEntry newProbabilityEntry;
const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom( const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
&newProbabilityEntry, unigramProperty); &newProbabilityEntry, &probabilityEntryOfUnigramProperty);
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
terminalId, &probabilityEntryToWrite); terminalId, &probabilityEntryToWrite);
} }
bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) {
// TODO: Support n-gram.
LanguageModelDictContent *const languageModelDictContent =
mBuffers->getMutableLanguageModelDictContent();
const ProbabilityEntry probabilityEntry =
languageModelDictContent->getNgramProbabilityEntry(
prevWordIds.limit(1 /* maxSize */), wordId);
const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty);
const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
&probabilityEntry, &probabilityEntryOfBigramProperty);
if (!languageModelDictContent->setNgramProbabilityEntry(
prevWordIds.limit(1 /* maxSize */), wordId, &updatedProbabilityEntry)) {
AKLOGE("Cannot add new ngram entry. prevWordId: %d, wordId: %d",
prevWordIds[0], wordId);
return false;
}
if (!probabilityEntry.isValid() && outAddedNewBigram) {
*outAddedNewBigram = true;
}
// TODO: Remove.
if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) { if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) {
AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
prevWordIds[0], wordId); prevWordIds[0], wordId);
return false; return false;
} }
@ -234,6 +255,7 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds, bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds,
const int wordId) { const int wordId) {
// TODO: Remove.
return mBigramPolicy->removeEntry(prevWordIds[0], wordId); return mBigramPolicy->removeEntry(prevWordIds[0], wordId);
} }
@ -352,20 +374,19 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry, const ProbabilityEntry *const originalProbabilityEntry,
const UnigramProperty *const unigramProperty) const { const ProbabilityEntry *const probabilityEntry) const {
// TODO: Consolidate historical info and probability. // TODO: Consolidate historical info and probability.
if (mHeaderPolicy->hasHistoricalInfoOfWords()) { if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(),
unigramProperty->getLevel(), unigramProperty->getCount());
const HistoricalInfo updatedHistoricalInfo = const HistoricalInfo updatedHistoricalInfo =
ForgettingCurveUtils::createUpdatedHistoricalInfo( ForgettingCurveUtils::createUpdatedHistoricalInfo(
originalProbabilityEntry->getHistoricalInfo(), originalProbabilityEntry->getHistoricalInfo(),
unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy); probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
mHeaderPolicy);
return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo( return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo(
&updatedHistoricalInfo); &updatedHistoricalInfo);
} else { } else {
return originalProbabilityEntry->createEntryWithUpdatedProbability( return originalProbabilityEntry->createEntryWithUpdatedProbability(
unigramProperty->getProbability()); probabilityEntry->getProbability());
} }
} }

View File

@ -98,12 +98,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
const PtNodeParams *const ptNodeParams, int *const outTerminalId, const PtNodeParams *const ptNodeParams, int *const outTerminalId,
int *const ptNodeWritingPos); int *const ptNodeWritingPos);
// Create updated probability entry using given unigram property. In addition to the // Create updated probability entry using given probability property. In addition to the
// probability, this method updates historical information if needed. // probability, this method updates historical information if needed.
// TODO: Update flags belonging to the unigram property. // TODO: Update flags.
const ProbabilityEntry createUpdatedEntryFrom( const ProbabilityEntry createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry, const ProbabilityEntry *const originalProbabilityEntry,
const UnigramProperty *const unigramProperty) const; const ProbabilityEntry *const probabilityEntry) const;
bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord, bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord,
const bool isTerminal, const bool hasMultipleChars); const bool isTerminal, const bool hasMultipleChars);

View File

@ -91,6 +91,11 @@ class IntArrayView {
return mPtr + mSize; return mPtr + mSize;
} }
// Returns a view that starts at the same element as this one and whose size is capped at
// maxSize; when maxSize is not smaller than the current size, the size is left unchanged.
const IntArrayView limit(const size_t maxSize) const {
    const size_t limitedSize = (maxSize < mSize) ? maxSize : mSize;
    return IntArrayView(mPtr, limitedSize);
}
private: private:
DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView); DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView);

View File

@ -53,9 +53,24 @@ TEST(IntArrayViewTest, TestConstructFromArray) {
TEST(IntArrayViewTest, TestConstructFromObject) { TEST(IntArrayViewTest, TestConstructFromObject) {
const int object = 10; const int object = 10;
const auto intArrayView = IntArrayView::fromObject(&object); const auto intArrayView = IntArrayView::fromObject(&object);
EXPECT_EQ(1, intArrayView.size()); EXPECT_EQ(1u, intArrayView.size());
EXPECT_EQ(object, intArrayView[0]); EXPECT_EQ(object, intArrayView[0]);
} }
// Checks IntArrayView::limit(): a zero limit gives an empty view, a limit at or above the
// current size keeps the size, and a smaller limit exposes the leading elements unchanged.
TEST(IntArrayViewTest, TestLimit) {
    const std::vector<int> sourceValues = {3, 2, 1, 0, -1, -2};
    IntArrayView fullView(sourceValues);

    // Limit of zero.
    EXPECT_TRUE(fullView.limit(0).empty());

    // Limits equal to and larger than the view size.
    EXPECT_EQ(fullView.size(), fullView.limit(fullView.size()).size());
    EXPECT_EQ(fullView.size(), fullView.limit(1000).size());

    // A limit smaller than the view size keeps the first elements in order.
    IntArrayView limitedView = fullView.limit(4);
    EXPECT_EQ(4u, limitedView.size());
    size_t index = 0;
    while (index < limitedView.size()) {
        EXPECT_EQ(sourceValues[index], limitedView[index]);
        ++index;
    }
}
} // namespace } // namespace
} // namespace latinime } // namespace latinime