Merge "Add BoS flag in probability entry."
This commit is contained in:
commit
ace03d7919
6 changed files with 30 additions and 26 deletions
|
@ -43,14 +43,13 @@ class ProbabilityEntry {
|
|||
: mFlags(flags), mProbability(probability), mHistoricalInfo() {}
|
||||
|
||||
// Entry with historical information.
|
||||
ProbabilityEntry(const int flags, const int probability,
|
||||
const HistoricalInfo *const historicalInfo)
|
||||
: mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}
|
||||
ProbabilityEntry(const int flags, const HistoricalInfo *const historicalInfo)
|
||||
: mFlags(flags), mProbability(NOT_A_PROBABILITY), mHistoricalInfo(*historicalInfo) {}
|
||||
|
||||
// Create from unigram property.
|
||||
// TODO: Set flags.
|
||||
ProbabilityEntry(const UnigramProperty *const unigramProperty)
|
||||
: mFlags(0), mProbability(unigramProperty->getProbability()),
|
||||
: mFlags(createFlags(unigramProperty->representsBeginningOfSentence())),
|
||||
mProbability(unigramProperty->getProbability()),
|
||||
mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
|
||||
unigramProperty->getCount()) {}
|
||||
|
||||
|
@ -61,15 +60,6 @@ class ProbabilityEntry {
|
|||
mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
|
||||
bigramProperty->getCount()) {}
|
||||
|
||||
const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
|
||||
return ProbabilityEntry(mFlags, probability, &mHistoricalInfo);
|
||||
}
|
||||
|
||||
const ProbabilityEntry createEntryWithUpdatedHistoricalInfo(
|
||||
const HistoricalInfo *const historicalInfo) const {
|
||||
return ProbabilityEntry(mFlags, mProbability, historicalInfo);
|
||||
}
|
||||
|
||||
bool isValid() const {
|
||||
return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
|
||||
}
|
||||
|
@ -78,7 +68,7 @@ class ProbabilityEntry {
|
|||
return mHistoricalInfo.isValid();
|
||||
}
|
||||
|
||||
int getFlags() const {
|
||||
uint8_t getFlags() const {
|
||||
return mFlags;
|
||||
}
|
||||
|
||||
|
@ -90,6 +80,10 @@ class ProbabilityEntry {
|
|||
return &mHistoricalInfo;
|
||||
}
|
||||
|
||||
bool representsBeginningOfSentence() const {
|
||||
return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0;
|
||||
}
|
||||
|
||||
uint64_t encode(const bool hasHistoricalInfo) const {
|
||||
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
|
||||
if (hasHistoricalInfo) {
|
||||
|
@ -123,7 +117,7 @@ class ProbabilityEntry {
|
|||
const int count = readFromEncodedEntry(encodedEntry,
|
||||
Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */);
|
||||
const HistoricalInfo historicalInfo(timestamp, level, count);
|
||||
return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
|
||||
return ProbabilityEntry(flags, &historicalInfo);
|
||||
} else {
|
||||
const int flags = readFromEncodedEntry(encodedEntry,
|
||||
Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
|
||||
|
@ -138,7 +132,7 @@ class ProbabilityEntry {
|
|||
// Copy constructor is public to use this class as a type of return value.
|
||||
DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry);
|
||||
|
||||
const int mFlags;
|
||||
const uint8_t mFlags;
|
||||
const int mProbability;
|
||||
const HistoricalInfo mHistoricalInfo;
|
||||
|
||||
|
@ -146,6 +140,14 @@ class ProbabilityEntry {
|
|||
return static_cast<int>(
|
||||
(encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1));
|
||||
}
|
||||
|
||||
static uint8_t createFlags(const bool representsBeginningOfSentence) {
|
||||
uint8_t flags = 0;
|
||||
if (representsBeginningOfSentence) {
|
||||
flags ^= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
|
||||
}
|
||||
return flags;
|
||||
}
|
||||
};
|
||||
} // namespace latinime
|
||||
#endif /* LATINIME_PROBABILITY_ENTRY_H */
|
||||
|
|
|
@ -54,6 +54,8 @@ const int Ver4DictConstants::TIME_STAMP_FIELD_SIZE = 4;
|
|||
const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1;
|
||||
const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1;
|
||||
|
||||
const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1;
|
||||
|
||||
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE = 16;
|
||||
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE = 4;
|
||||
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "defines.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
namespace latinime {
|
||||
|
||||
|
@ -48,6 +49,8 @@ class Ver4DictConstants {
|
|||
static const int TIME_STAMP_FIELD_SIZE;
|
||||
static const int WORD_LEVEL_FIELD_SIZE;
|
||||
static const int WORD_COUNT_FIELD_SIZE;
|
||||
// Flags in probability entry.
|
||||
static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
|
||||
|
||||
static const int BIGRAM_ADDRESS_TABLE_BLOCK_SIZE;
|
||||
static const int BIGRAM_ADDRESS_TABLE_DATA_SIZE;
|
||||
|
|
|
@ -164,8 +164,8 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
|
|||
if (originalProbabilityEntry.hasHistoricalInfo()) {
|
||||
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
|
||||
originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
|
||||
const ProbabilityEntry probabilityEntry =
|
||||
originalProbabilityEntry.createEntryWithUpdatedHistoricalInfo(&historicalInfo);
|
||||
const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
|
||||
&historicalInfo);
|
||||
if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
|
||||
toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
|
||||
AKLOGE("Cannot write updated probability entry. terminalId: %d",
|
||||
|
@ -383,18 +383,15 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
|
|||
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
|
||||
const ProbabilityEntry *const originalProbabilityEntry,
|
||||
const ProbabilityEntry *const probabilityEntry) const {
|
||||
// TODO: Consolidate historical info and probability.
|
||||
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
|
||||
const HistoricalInfo updatedHistoricalInfo =
|
||||
ForgettingCurveUtils::createUpdatedHistoricalInfo(
|
||||
originalProbabilityEntry->getHistoricalInfo(),
|
||||
probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
|
||||
mHeaderPolicy);
|
||||
return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo(
|
||||
&updatedHistoricalInfo);
|
||||
return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo);
|
||||
} else {
|
||||
return originalProbabilityEntry->createEntryWithUpdatedProbability(
|
||||
probabilityEntry->getProbability());
|
||||
return *probabilityEntry;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
|
|||
const int count = 10;
|
||||
const int wordId = 100;
|
||||
const HistoricalInfo historicalInfo(timestamp, level, count);
|
||||
const ProbabilityEntry probabilityEntry(flag, NOT_A_PROBABILITY, &historicalInfo);
|
||||
const ProbabilityEntry probabilityEntry(flag, &historicalInfo);
|
||||
LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
|
||||
const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId);
|
||||
EXPECT_EQ(flag, entry.getFlags());
|
||||
|
|
|
@ -43,7 +43,7 @@ TEST(ProbabilityEntryTest, TestEncodeDecodeWithHistoricalInfo) {
|
|||
const int count = 10;
|
||||
|
||||
const HistoricalInfo historicalInfo(timestamp, level, count);
|
||||
const ProbabilityEntry entry(flag, NOT_A_PROBABILITY, &historicalInfo);
|
||||
const ProbabilityEntry entry(flag, &historicalInfo);
|
||||
|
||||
const uint64_t encodedEntry = entry.encode(true /* hasHistoricalInfo */);
|
||||
EXPECT_EQ(0xF03FFFFFFF030Aull, encodedEntry);
|
||||
|
|
Loading…
Reference in a new issue