Add BoS flag in probability entry.
Bug: 14425059 Change-Id: I50439630034ada0280c44cbbb308aa0b95b72048
This commit is contained in:
parent
2f34f0d1a8
commit
623067a183
6 changed files with 30 additions and 26 deletions
|
@ -43,14 +43,13 @@ class ProbabilityEntry {
|
||||||
: mFlags(flags), mProbability(probability), mHistoricalInfo() {}
|
: mFlags(flags), mProbability(probability), mHistoricalInfo() {}
|
||||||
|
|
||||||
// Entry with historical information.
|
// Entry with historical information.
|
||||||
ProbabilityEntry(const int flags, const int probability,
|
ProbabilityEntry(const int flags, const HistoricalInfo *const historicalInfo)
|
||||||
const HistoricalInfo *const historicalInfo)
|
: mFlags(flags), mProbability(NOT_A_PROBABILITY), mHistoricalInfo(*historicalInfo) {}
|
||||||
: mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}
|
|
||||||
|
|
||||||
// Create from unigram property.
|
// Create from unigram property.
|
||||||
// TODO: Set flags.
|
|
||||||
ProbabilityEntry(const UnigramProperty *const unigramProperty)
|
ProbabilityEntry(const UnigramProperty *const unigramProperty)
|
||||||
: mFlags(0), mProbability(unigramProperty->getProbability()),
|
: mFlags(createFlags(unigramProperty->representsBeginningOfSentence())),
|
||||||
|
mProbability(unigramProperty->getProbability()),
|
||||||
mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
|
mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
|
||||||
unigramProperty->getCount()) {}
|
unigramProperty->getCount()) {}
|
||||||
|
|
||||||
|
@ -61,15 +60,6 @@ class ProbabilityEntry {
|
||||||
mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
|
mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
|
||||||
bigramProperty->getCount()) {}
|
bigramProperty->getCount()) {}
|
||||||
|
|
||||||
const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
|
|
||||||
return ProbabilityEntry(mFlags, probability, &mHistoricalInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
const ProbabilityEntry createEntryWithUpdatedHistoricalInfo(
|
|
||||||
const HistoricalInfo *const historicalInfo) const {
|
|
||||||
return ProbabilityEntry(mFlags, mProbability, historicalInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isValid() const {
|
bool isValid() const {
|
||||||
return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
|
return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
|
||||||
}
|
}
|
||||||
|
@ -78,7 +68,7 @@ class ProbabilityEntry {
|
||||||
return mHistoricalInfo.isValid();
|
return mHistoricalInfo.isValid();
|
||||||
}
|
}
|
||||||
|
|
||||||
int getFlags() const {
|
uint8_t getFlags() const {
|
||||||
return mFlags;
|
return mFlags;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,6 +80,10 @@ class ProbabilityEntry {
|
||||||
return &mHistoricalInfo;
|
return &mHistoricalInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool representsBeginningOfSentence() const {
|
||||||
|
return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0;
|
||||||
|
}
|
||||||
|
|
||||||
uint64_t encode(const bool hasHistoricalInfo) const {
|
uint64_t encode(const bool hasHistoricalInfo) const {
|
||||||
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
|
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
|
||||||
if (hasHistoricalInfo) {
|
if (hasHistoricalInfo) {
|
||||||
|
@ -123,7 +117,7 @@ class ProbabilityEntry {
|
||||||
const int count = readFromEncodedEntry(encodedEntry,
|
const int count = readFromEncodedEntry(encodedEntry,
|
||||||
Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */);
|
Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */);
|
||||||
const HistoricalInfo historicalInfo(timestamp, level, count);
|
const HistoricalInfo historicalInfo(timestamp, level, count);
|
||||||
return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
|
return ProbabilityEntry(flags, &historicalInfo);
|
||||||
} else {
|
} else {
|
||||||
const int flags = readFromEncodedEntry(encodedEntry,
|
const int flags = readFromEncodedEntry(encodedEntry,
|
||||||
Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
|
Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
|
||||||
|
@ -138,7 +132,7 @@ class ProbabilityEntry {
|
||||||
// Copy constructor is public to use this class as a type of return value.
|
// Copy constructor is public to use this class as a type of return value.
|
||||||
DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry);
|
DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry);
|
||||||
|
|
||||||
const int mFlags;
|
const uint8_t mFlags;
|
||||||
const int mProbability;
|
const int mProbability;
|
||||||
const HistoricalInfo mHistoricalInfo;
|
const HistoricalInfo mHistoricalInfo;
|
||||||
|
|
||||||
|
@ -146,6 +140,14 @@ class ProbabilityEntry {
|
||||||
return static_cast<int>(
|
return static_cast<int>(
|
||||||
(encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1));
|
(encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static uint8_t createFlags(const bool representsBeginningOfSentence) {
|
||||||
|
uint8_t flags = 0;
|
||||||
|
if (representsBeginningOfSentence) {
|
||||||
|
flags ^= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
|
||||||
|
}
|
||||||
|
return flags;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
#endif /* LATINIME_PROBABILITY_ENTRY_H */
|
#endif /* LATINIME_PROBABILITY_ENTRY_H */
|
||||||
|
|
|
@ -54,6 +54,8 @@ const int Ver4DictConstants::TIME_STAMP_FIELD_SIZE = 4;
|
||||||
const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1;
|
const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1;
|
||||||
const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1;
|
const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1;
|
||||||
|
|
||||||
|
const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1;
|
||||||
|
|
||||||
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE = 16;
|
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE = 16;
|
||||||
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE = 4;
|
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE = 4;
|
||||||
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64;
|
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64;
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include "defines.h"
|
#include "defines.h"
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
|
@ -48,6 +49,8 @@ class Ver4DictConstants {
|
||||||
static const int TIME_STAMP_FIELD_SIZE;
|
static const int TIME_STAMP_FIELD_SIZE;
|
||||||
static const int WORD_LEVEL_FIELD_SIZE;
|
static const int WORD_LEVEL_FIELD_SIZE;
|
||||||
static const int WORD_COUNT_FIELD_SIZE;
|
static const int WORD_COUNT_FIELD_SIZE;
|
||||||
|
// Flags in probability entry.
|
||||||
|
static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
|
||||||
|
|
||||||
static const int BIGRAM_ADDRESS_TABLE_BLOCK_SIZE;
|
static const int BIGRAM_ADDRESS_TABLE_BLOCK_SIZE;
|
||||||
static const int BIGRAM_ADDRESS_TABLE_DATA_SIZE;
|
static const int BIGRAM_ADDRESS_TABLE_DATA_SIZE;
|
||||||
|
|
|
@ -164,8 +164,8 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
|
||||||
if (originalProbabilityEntry.hasHistoricalInfo()) {
|
if (originalProbabilityEntry.hasHistoricalInfo()) {
|
||||||
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
|
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
|
||||||
originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
|
originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
|
||||||
const ProbabilityEntry probabilityEntry =
|
const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
|
||||||
originalProbabilityEntry.createEntryWithUpdatedHistoricalInfo(&historicalInfo);
|
&historicalInfo);
|
||||||
if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
|
if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
|
||||||
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
|
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
|
||||||
AKLOGE("Cannot write updated probability entry. terminalId: %d",
|
AKLOGE("Cannot write updated probability entry. terminalId: %d",
|
||||||
|
@ -383,18 +383,15 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
|
||||||
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
|
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
|
||||||
const ProbabilityEntry *const originalProbabilityEntry,
|
const ProbabilityEntry *const originalProbabilityEntry,
|
||||||
const ProbabilityEntry *const probabilityEntry) const {
|
const ProbabilityEntry *const probabilityEntry) const {
|
||||||
// TODO: Consolidate historical info and probability.
|
|
||||||
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
|
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
|
||||||
const HistoricalInfo updatedHistoricalInfo =
|
const HistoricalInfo updatedHistoricalInfo =
|
||||||
ForgettingCurveUtils::createUpdatedHistoricalInfo(
|
ForgettingCurveUtils::createUpdatedHistoricalInfo(
|
||||||
originalProbabilityEntry->getHistoricalInfo(),
|
originalProbabilityEntry->getHistoricalInfo(),
|
||||||
probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
|
probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
|
||||||
mHeaderPolicy);
|
mHeaderPolicy);
|
||||||
return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo(
|
return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo);
|
||||||
&updatedHistoricalInfo);
|
|
||||||
} else {
|
} else {
|
||||||
return originalProbabilityEntry->createEntryWithUpdatedProbability(
|
return *probabilityEntry;
|
||||||
probabilityEntry->getProbability());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -53,7 +53,7 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
|
||||||
const int count = 10;
|
const int count = 10;
|
||||||
const int wordId = 100;
|
const int wordId = 100;
|
||||||
const HistoricalInfo historicalInfo(timestamp, level, count);
|
const HistoricalInfo historicalInfo(timestamp, level, count);
|
||||||
const ProbabilityEntry probabilityEntry(flag, NOT_A_PROBABILITY, &historicalInfo);
|
const ProbabilityEntry probabilityEntry(flag, &historicalInfo);
|
||||||
LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
||||||
const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId);
|
const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId);
|
||||||
EXPECT_EQ(flag, entry.getFlags());
|
EXPECT_EQ(flag, entry.getFlags());
|
||||||
|
|
|
@ -43,7 +43,7 @@ TEST(ProbabilityEntryTest, TestEncodeDecodeWithHistoricalInfo) {
|
||||||
const int count = 10;
|
const int count = 10;
|
||||||
|
|
||||||
const HistoricalInfo historicalInfo(timestamp, level, count);
|
const HistoricalInfo historicalInfo(timestamp, level, count);
|
||||||
const ProbabilityEntry entry(flag, NOT_A_PROBABILITY, &historicalInfo);
|
const ProbabilityEntry entry(flag, &historicalInfo);
|
||||||
|
|
||||||
const uint64_t encodedEntry = entry.encode(true /* hasHistoricalInfo */);
|
const uint64_t encodedEntry = entry.encode(true /* hasHistoricalInfo */);
|
||||||
EXPECT_EQ(0xF03FFFFFFF030Aull, encodedEntry);
|
EXPECT_EQ(0xF03FFFFFFF030Aull, encodedEntry);
|
||||||
|
|
Loading…
Reference in a new issue