Merge "Add BoS flag in probability entry."

This commit is contained in:
Keisuke Kuroyanagi 2014-08-19 03:02:23 +00:00 committed by Android (Google) Code Review
commit ace03d7919
6 changed files with 30 additions and 26 deletions

View file

@ -43,14 +43,13 @@ class ProbabilityEntry {
: mFlags(flags), mProbability(probability), mHistoricalInfo() {} : mFlags(flags), mProbability(probability), mHistoricalInfo() {}
// Entry with historical information. // Entry with historical information.
ProbabilityEntry(const int flags, const int probability, ProbabilityEntry(const int flags, const HistoricalInfo *const historicalInfo)
const HistoricalInfo *const historicalInfo) : mFlags(flags), mProbability(NOT_A_PROBABILITY), mHistoricalInfo(*historicalInfo) {}
: mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}
// Create from unigram property. // Create from unigram property.
// TODO: Set flags.
ProbabilityEntry(const UnigramProperty *const unigramProperty) ProbabilityEntry(const UnigramProperty *const unigramProperty)
: mFlags(0), mProbability(unigramProperty->getProbability()), : mFlags(createFlags(unigramProperty->representsBeginningOfSentence())),
mProbability(unigramProperty->getProbability()),
mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(), mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
unigramProperty->getCount()) {} unigramProperty->getCount()) {}
@ -61,15 +60,6 @@ class ProbabilityEntry {
mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(), mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
bigramProperty->getCount()) {} bigramProperty->getCount()) {}
const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
return ProbabilityEntry(mFlags, probability, &mHistoricalInfo);
}
const ProbabilityEntry createEntryWithUpdatedHistoricalInfo(
const HistoricalInfo *const historicalInfo) const {
return ProbabilityEntry(mFlags, mProbability, historicalInfo);
}
bool isValid() const { bool isValid() const {
return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo(); return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
} }
@ -78,7 +68,7 @@ class ProbabilityEntry {
return mHistoricalInfo.isValid(); return mHistoricalInfo.isValid();
} }
int getFlags() const { uint8_t getFlags() const {
return mFlags; return mFlags;
} }
@ -90,6 +80,10 @@ class ProbabilityEntry {
return &mHistoricalInfo; return &mHistoricalInfo;
} }
bool representsBeginningOfSentence() const {
return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0;
}
uint64_t encode(const bool hasHistoricalInfo) const { uint64_t encode(const bool hasHistoricalInfo) const {
uint64_t encodedEntry = static_cast<uint64_t>(mFlags); uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
if (hasHistoricalInfo) { if (hasHistoricalInfo) {
@ -123,7 +117,7 @@ class ProbabilityEntry {
const int count = readFromEncodedEntry(encodedEntry, const int count = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */); Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */);
const HistoricalInfo historicalInfo(timestamp, level, count); const HistoricalInfo historicalInfo(timestamp, level, count);
return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo); return ProbabilityEntry(flags, &historicalInfo);
} else { } else {
const int flags = readFromEncodedEntry(encodedEntry, const int flags = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE, Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
@ -138,7 +132,7 @@ class ProbabilityEntry {
// Copy constructor is public to use this class as a type of return value. // Copy constructor is public to use this class as a type of return value.
DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry); DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry);
const int mFlags; const uint8_t mFlags;
const int mProbability; const int mProbability;
const HistoricalInfo mHistoricalInfo; const HistoricalInfo mHistoricalInfo;
@ -146,6 +140,14 @@ class ProbabilityEntry {
return static_cast<int>( return static_cast<int>(
(encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1)); (encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1));
} }
static uint8_t createFlags(const bool representsBeginningOfSentence) {
uint8_t flags = 0;
if (representsBeginningOfSentence) {
flags ^= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
}
return flags;
}
}; };
} // namespace latinime } // namespace latinime
#endif /* LATINIME_PROBABILITY_ENTRY_H */ #endif /* LATINIME_PROBABILITY_ENTRY_H */

View file

@ -54,6 +54,8 @@ const int Ver4DictConstants::TIME_STAMP_FIELD_SIZE = 4;
const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1; const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1;
const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1; const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1;
const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1;
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE = 16; const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE = 16;
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE = 4; const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE = 4;
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64;

View file

@ -20,6 +20,7 @@
#include "defines.h" #include "defines.h"
#include <cstddef> #include <cstddef>
#include <cstdint>
namespace latinime { namespace latinime {
@ -48,6 +49,8 @@ class Ver4DictConstants {
static const int TIME_STAMP_FIELD_SIZE; static const int TIME_STAMP_FIELD_SIZE;
static const int WORD_LEVEL_FIELD_SIZE; static const int WORD_LEVEL_FIELD_SIZE;
static const int WORD_COUNT_FIELD_SIZE; static const int WORD_COUNT_FIELD_SIZE;
// Flags in probability entry.
static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
static const int BIGRAM_ADDRESS_TABLE_BLOCK_SIZE; static const int BIGRAM_ADDRESS_TABLE_BLOCK_SIZE;
static const int BIGRAM_ADDRESS_TABLE_DATA_SIZE; static const int BIGRAM_ADDRESS_TABLE_DATA_SIZE;

View file

@ -164,8 +164,8 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
if (originalProbabilityEntry.hasHistoricalInfo()) { if (originalProbabilityEntry.hasHistoricalInfo()) {
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy); originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
const ProbabilityEntry probabilityEntry = const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
originalProbabilityEntry.createEntryWithUpdatedHistoricalInfo(&historicalInfo); &historicalInfo);
if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) { toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
AKLOGE("Cannot write updated probability entry. terminalId: %d", AKLOGE("Cannot write updated probability entry. terminalId: %d",
@ -383,18 +383,15 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry, const ProbabilityEntry *const originalProbabilityEntry,
const ProbabilityEntry *const probabilityEntry) const { const ProbabilityEntry *const probabilityEntry) const {
// TODO: Consolidate historical info and probability.
if (mHeaderPolicy->hasHistoricalInfoOfWords()) { if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
const HistoricalInfo updatedHistoricalInfo = const HistoricalInfo updatedHistoricalInfo =
ForgettingCurveUtils::createUpdatedHistoricalInfo( ForgettingCurveUtils::createUpdatedHistoricalInfo(
originalProbabilityEntry->getHistoricalInfo(), originalProbabilityEntry->getHistoricalInfo(),
probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(), probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
mHeaderPolicy); mHeaderPolicy);
return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo( return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo);
&updatedHistoricalInfo);
} else { } else {
return originalProbabilityEntry->createEntryWithUpdatedProbability( return *probabilityEntry;
probabilityEntry->getProbability());
} }
} }

View file

@ -53,7 +53,7 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
const int count = 10; const int count = 10;
const int wordId = 100; const int wordId = 100;
const HistoricalInfo historicalInfo(timestamp, level, count); const HistoricalInfo historicalInfo(timestamp, level, count);
const ProbabilityEntry probabilityEntry(flag, NOT_A_PROBABILITY, &historicalInfo); const ProbabilityEntry probabilityEntry(flag, &historicalInfo);
LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId); const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId);
EXPECT_EQ(flag, entry.getFlags()); EXPECT_EQ(flag, entry.getFlags());

View file

@ -43,7 +43,7 @@ TEST(ProbabilityEntryTest, TestEncodeDecodeWithHistoricalInfo) {
const int count = 10; const int count = 10;
const HistoricalInfo historicalInfo(timestamp, level, count); const HistoricalInfo historicalInfo(timestamp, level, count);
const ProbabilityEntry entry(flag, NOT_A_PROBABILITY, &historicalInfo); const ProbabilityEntry entry(flag, &historicalInfo);
const uint64_t encodedEntry = entry.encode(true /* hasHistoricalInfo */); const uint64_t encodedEntry = entry.encode(true /* hasHistoricalInfo */);
EXPECT_EQ(0xF03FFFFFFF030Aull, encodedEntry); EXPECT_EQ(0xF03FFFFFFF030Aull, encodedEntry);