am 20da4f07: Merge "Use enum to specify ngram type."

* commit '20da4f07be9cdf58835a79e619785b4cafd428ff':
  Use enum to specify ngram type.
This commit is contained in:
Keisuke Kuroyanagi 2014-11-25 10:39:04 +00:00 committed by Android Git Automerger
commit adb4f0ed20
14 changed files with 218 additions and 251 deletions

View file

@ -18,6 +18,8 @@
#include <algorithm> #include <algorithm>
#include "utils/ngram_utils.h"
namespace latinime { namespace latinime {
// Note that these are corresponding definitions in Java side in DictionaryHeader. // Note that these are corresponding definitions in Java side in DictionaryHeader.
@ -28,9 +30,11 @@ const char *const HeaderPolicy::REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY =
const char *const HeaderPolicy::IS_DECAYING_DICT_KEY = "USES_FORGETTING_CURVE"; const char *const HeaderPolicy::IS_DECAYING_DICT_KEY = "USES_FORGETTING_CURVE";
const char *const HeaderPolicy::DATE_KEY = "date"; const char *const HeaderPolicy::DATE_KEY = "date";
const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME"; const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME";
const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT"; const char *const HeaderPolicy::NGRAM_COUNT_KEYS[] =
const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT"; {"UNIGRAM_COUNT", "BIGRAM_COUNT", "TRIGRAM_COUNT"};
const char *const HeaderPolicy::TRIGRAM_COUNT_KEY = "TRIGRAM_COUNT"; const char *const HeaderPolicy::MAX_NGRAM_COUNT_KEYS[] =
{"MAX_UNIGRAM_ENTRY_COUNT", "MAX_BIGRAM_ENTRY_COUNT", "MAX_TRIGRAM_ENTRY_COUNT"};
const int HeaderPolicy::DEFAULT_MAX_NGRAM_COUNTS[] = {10000, 30000, 30000};
const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE"; const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE";
// Historical info is information that is needed to support decaying such as timestamp, level and // Historical info is information that is needed to support decaying such as timestamp, level and
// count. // count.
@ -39,18 +43,10 @@ const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY = const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
"FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID"; "FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT";
const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT";
const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT";
const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100; const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f; const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3; const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000;
const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000;
// Used for logging. Question mark is used to indicate that the key is not found. // Used for logging. Question mark is used to indicate that the key is not found.
void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue, void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue,
int outValueSize) const { int outValueSize) const {
@ -126,15 +122,22 @@ bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTim
return true; return true;
} }
namespace {
int getIndexFromNgramType(const NgramType ngramType) {
return static_cast<int>(ngramType);
}
} // namespace
void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
const EntryCounts &entryCounts, const int extendedRegionSize, const EntryCounts &entryCounts, const int extendedRegionSize,
DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const { DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const {
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY, for (const auto ngramType : AllNgramTypes::ASCENDING) {
entryCounts.getUnigramCount()); HeaderReadWriteUtils::setIntAttribute(outAttributeMap,
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY, NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)],
entryCounts.getBigramCount()); entryCounts.getNgramCount(ngramType));
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, TRIGRAM_COUNT_KEY, }
entryCounts.getTrigramCount());
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY, HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY,
extendedRegionSize); extendedRegionSize);
// Set the current time as the generation time. // Set the current time as the generation time.
@ -155,4 +158,25 @@ void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
return attributeMap; return attributeMap;
} }
/* static */ const EntryCounts HeaderPolicy::readNgramCounts() const {
MutableEntryCounters entryCounters;
for (const auto ngramType : AllNgramTypes::ASCENDING) {
const int entryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)], 0 /* defaultValue */);
entryCounters.setNgramCount(ngramType, entryCount);
}
return entryCounters.getEntryCounts();
}
/* static */ const EntryCounts HeaderPolicy::readMaxNgramCounts() const {
MutableEntryCounters entryCounters;
for (const auto ngramType : AllNgramTypes::ASCENDING) {
const int index = getIndexFromNgramType(ngramType);
const int maxEntryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
MAX_NGRAM_COUNT_KEYS[index], DEFAULT_MAX_NGRAM_COUNTS[index]);
entryCounters.setNgramCount(ngramType, maxEntryCount);
}
return entryCounters.getEntryCounts();
}
} // namespace latinime } // namespace latinime

View file

@ -46,12 +46,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
LAST_DECAYED_TIME_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), LAST_DECAYED_TIME_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()),
UNIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
BIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
TRIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)), EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
@ -59,12 +54,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue( mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY, &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)), DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Constructs header information using an attribute map. // Constructs header information using an attribute map.
@ -82,18 +71,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0), mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()),
mExtendedRegionSize(0),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)), &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue( mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY, &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)), DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Copy header information // Copy header information
@ -105,15 +89,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mRequiresGermanUmlautProcessing(headerPolicy->mRequiresGermanUmlautProcessing), mRequiresGermanUmlautProcessing(headerPolicy->mRequiresGermanUmlautProcessing),
mIsDecayingDict(headerPolicy->mIsDecayingDict), mIsDecayingDict(headerPolicy->mIsDecayingDict),
mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime), mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime),
mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount), mNgramCounts(headerPolicy->mNgramCounts),
mTrigramCount(headerPolicy->mTrigramCount), mMaxNgramCounts(headerPolicy->mMaxNgramCounts),
mExtendedRegionSize(headerPolicy->mExtendedRegionSize), mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords), mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
mForgettingCurveProbabilityValuesTableId( mForgettingCurveProbabilityValuesTableId(
headerPolicy->mForgettingCurveProbabilityValuesTableId), headerPolicy->mForgettingCurveProbabilityValuesTableId),
mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
mMaxBigramCount(headerPolicy->mMaxBigramCount),
mMaxTrigramCount(headerPolicy->mMaxTrigramCount),
mCodePointTable(headerPolicy->mCodePointTable) {} mCodePointTable(headerPolicy->mCodePointTable) {}
// Temporary dummy header. // Temporary dummy header.
@ -121,10 +102,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
: mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0), : mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0),
mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f), mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f),
mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false), mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mDate(0), mLastDecayedTime(0), mNgramCounts(), mMaxNgramCounts(),
mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0), mForgettingCurveProbabilityValuesTableId(0), mCodePointTable(nullptr) {}
mMaxTrigramCount(0), mCodePointTable(nullptr) {}
~HeaderPolicy() {} ~HeaderPolicy() {}
@ -186,16 +166,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mLastDecayedTime; return mLastDecayedTime;
} }
AK_FORCE_INLINE int getUnigramCount() const { AK_FORCE_INLINE const EntryCounts &getNgramCounts() const {
return mUnigramCount; return mNgramCounts;
} }
AK_FORCE_INLINE int getBigramCount() const { AK_FORCE_INLINE const EntryCounts getMaxNgramCounts() const {
return mBigramCount; return mMaxNgramCounts;
}
AK_FORCE_INLINE int getTrigramCount() const {
return mTrigramCount;
} }
AK_FORCE_INLINE int getExtendedRegionSize() const { AK_FORCE_INLINE int getExtendedRegionSize() const {
@ -219,18 +195,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mForgettingCurveProbabilityValuesTableId; return mForgettingCurveProbabilityValuesTableId;
} }
AK_FORCE_INLINE int getMaxUnigramCount() const {
return mMaxUnigramCount;
}
AK_FORCE_INLINE int getMaxBigramCount() const {
return mMaxBigramCount;
}
AK_FORCE_INLINE int getMaxTrigramCount() const {
return mMaxTrigramCount;
}
void readHeaderValueOrQuestionMark(const char *const key, void readHeaderValueOrQuestionMark(const char *const key,
int *outValue, int outValueSize) const; int *outValue, int outValueSize) const;
@ -262,24 +226,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const IS_DECAYING_DICT_KEY; static const char *const IS_DECAYING_DICT_KEY;
static const char *const DATE_KEY; static const char *const DATE_KEY;
static const char *const LAST_DECAYED_TIME_KEY; static const char *const LAST_DECAYED_TIME_KEY;
static const char *const UNIGRAM_COUNT_KEY; static const char *const NGRAM_COUNT_KEYS[];
static const char *const BIGRAM_COUNT_KEY; static const char *const MAX_NGRAM_COUNT_KEYS[];
static const char *const TRIGRAM_COUNT_KEY; static const int DEFAULT_MAX_NGRAM_COUNTS[];
static const char *const EXTENDED_REGION_SIZE_KEY; static const char *const EXTENDED_REGION_SIZE_KEY;
static const char *const HAS_HISTORICAL_INFO_KEY; static const char *const HAS_HISTORICAL_INFO_KEY;
static const char *const LOCALE_KEY; static const char *const LOCALE_KEY;
static const char *const FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY; static const char *const FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY;
static const char *const FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY; static const char *const FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY;
static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY; static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY;
static const char *const MAX_UNIGRAM_COUNT_KEY;
static const char *const MAX_BIGRAM_COUNT_KEY;
static const char *const MAX_TRIGRAM_COUNT_KEY;
static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE; static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID; static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
static const int DEFAULT_MAX_UNIGRAM_COUNT;
static const int DEFAULT_MAX_BIGRAM_COUNT;
static const int DEFAULT_MAX_TRIGRAM_COUNT;
const FormatUtils::FORMAT_VERSION mDictFormatVersion; const FormatUtils::FORMAT_VERSION mDictFormatVersion;
const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags; const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags;
@ -291,21 +249,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const bool mIsDecayingDict; const bool mIsDecayingDict;
const int mDate; const int mDate;
const int mLastDecayedTime; const int mLastDecayedTime;
const int mUnigramCount; const EntryCounts mNgramCounts;
const int mBigramCount; const EntryCounts mMaxNgramCounts;
const int mTrigramCount;
const int mExtendedRegionSize; const int mExtendedRegionSize;
const bool mHasHistoricalInfoOfWords; const bool mHasHistoricalInfoOfWords;
const int mForgettingCurveProbabilityValuesTableId; const int mForgettingCurveProbabilityValuesTableId;
const int mMaxUnigramCount;
const int mMaxBigramCount;
const int mMaxTrigramCount;
const int *const mCodePointTable; const int *const mCodePointTable;
const std::vector<int> readLocale() const; const std::vector<int> readLocale() const;
float readMultipleWordCostMultiplier() const; float readMultipleWordCostMultiplier() const;
bool readRequiresGermanUmlautProcessing() const; bool readRequiresGermanUmlautProcessing() const;
const EntryCounts readNgramCounts() const;
const EntryCounts readMaxNgramCounts() const;
static DictionaryHeaderStructurePolicy::AttributeMap createAttributeMapAndReadAllAttributes( static DictionaryHeaderStructurePolicy::AttributeMap createAttributeMapAndReadAllAttributes(
const uint8_t *const dictBuf); const uint8_t *const dictBuf);
}; };

View file

@ -303,7 +303,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) { &addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
mEntryCounters.incrementUnigramCount(); mEntryCounters.incrementNgramCount(NgramType::Unigram);
} }
if (unigramProperty->getShortcuts().size() > 0) { if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target. // Add shortcut target.
@ -397,7 +397,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos), if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, ngramProperty, &addedNewBigram)) { wordPos, ngramProperty, &addedNewBigram)) {
if (addedNewBigram) { if (addedNewBigram) {
mEntryCounters.incrementBigramCount(); mEntryCounters.incrementNgramCount(NgramType::Bigram);
} }
return true; return true;
} else { } else {
@ -438,7 +438,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry( if (mUpdatingHelper.removeNgramEntry(
PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) { PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
mEntryCounters.decrementBigramCount(); mEntryCounters.decrementNgramCount(NgramType::Bigram);
return true; return true;
} else { } else {
return false; return false;
@ -525,20 +525,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) { char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */; const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) { if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount()); snprintf(outResult, maxResultLength, "%d",
mEntryCounters.getNgramCount(NgramType::Unigram));
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) { } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount()); snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram));
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) { } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ? mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit( ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxUnigramCount()) : mHeaderPolicy->getMaxNgramCounts().getNgramCount(
NgramType::Unigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) { } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ? mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit( ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxBigramCount()) : mHeaderPolicy->getMaxNgramCounts().getNgramCount(
NgramType::Bigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} }
} }

View file

@ -76,8 +76,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy), &mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()), mWritingHelper(mBuffers.get()),
mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(), mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()),
mHeaderPolicy->getTrigramCount()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}; mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
virtual int getRootPosition() const { virtual int getRootPosition() const {

View file

@ -53,8 +53,8 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat
entryCounts, extendedRegionSize, &headerBuffer)) { entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. " AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, " "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, "
"extendedRegionSize: %d", false, entryCounts.getUnigramCount(), "extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram),
entryCounts.getBigramCount(), extendedRegionSize); entryCounts.getNgramCount(NgramType::Bigram), extendedRegionSize);
return false; return false;
} }
return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@ -73,9 +73,11 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
} }
BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
MutableEntryCounters entryCounters;
entryCounters.setNgramCount(NgramType::Unigram, unigramCount);
entryCounters.setNgramCount(NgramType::Bigram, bigramCount);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */), entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
0 /* extendedRegionSize */, &headerBuffer)) {
return false; return false;
} }
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@ -107,7 +109,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
} }
const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted
.getValidUnigramCount(); .getValidUnigramCount();
const int maxUnigramCount = headerPolicy->getMaxUnigramCount(); const int maxUnigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Unigram);
if (headerPolicy->isDecayingDict() && unigramCount > maxUnigramCount) { if (headerPolicy->isDecayingDict() && unigramCount > maxUnigramCount) {
if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter, maxUnigramCount)) { if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter, maxUnigramCount)) {
AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount, AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount,
@ -124,7 +126,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return false; return false;
} }
const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount(); const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount();
const int maxBigramCount = headerPolicy->getMaxBigramCount(); const int maxBigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Bigram);
if (headerPolicy->isDecayingDict() && bigramCount > maxBigramCount) { if (headerPolicy->isDecayingDict() && bigramCount > maxBigramCount) {
if (!truncateBigrams(maxBigramCount)) { if (!truncateBigrams(maxBigramCount)) {
AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount, maxBigramCount); AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount, maxBigramCount);

View file

@ -18,17 +18,13 @@
namespace latinime { namespace latinime {
// These counts are used to provide stable probabilities even if the user's input count is small. // Used to provide stable probabilities even if the user's input count is small.
const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_UNIGRAMS = 8192; const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNTS[] = {8192, 2, 2};
const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_BIGRAMS = 2;
const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_TRIGRAMS = 2;
// These are encoded backoff weights. // Encoded backoff weights.
// Note that we give positive value for trigrams that means the weight is more than 1. // Note that we give positive value for trigrams that means the weight is more than 1.
// TODO: Apply backoff for main dictionaries and quit giving a positive backoff weight. // TODO: Apply backoff for main dictionaries and quit giving a positive backoff weight.
const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS = -32; const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHTS[] = {-32, 0, 8};
const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS = 0;
const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS = 8;
// This value is used to remove too old entries from the dictionary. // This value is used to remove too old entries from the dictionary.
const int DynamicLanguageModelProbabilityUtils::DURATION_TO_DISCARD_ENTRY_IN_SECONDS = const int DynamicLanguageModelProbabilityUtils::DURATION_TO_DISCARD_ENTRY_IN_SECONDS =

View file

@ -21,6 +21,7 @@
#include "defines.h" #include "defines.h"
#include "suggest/core/dictionary/property/historical_info.h" #include "suggest/core/dictionary/property/historical_info.h"
#include "utils/ngram_utils.h"
#include "utils/time_keeper.h" #include "utils/time_keeper.h"
namespace latinime { namespace latinime {
@ -28,46 +29,14 @@ namespace latinime {
class DynamicLanguageModelProbabilityUtils { class DynamicLanguageModelProbabilityUtils {
public: public:
static float computeRawProbabilityFromCounts(const int count, const int contextCount, static float computeRawProbabilityFromCounts(const int count, const int contextCount,
const int matchedWordCountInContext) { const NgramType ngramType) {
int minCount = 0; const int minCount = ASSUMED_MIN_COUNTS[static_cast<int>(ngramType)];
switch (matchedWordCountInContext) {
case 1:
minCount = ASSUMED_MIN_COUNT_FOR_UNIGRAMS;
break;
case 2:
minCount = ASSUMED_MIN_COUNT_FOR_BIGRAMS;
break;
case 3:
minCount = ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
break;
default:
AKLOGE("computeRawProbabilityFromCounts is called with invalid "
"matchedWordCountInContext (%d).", matchedWordCountInContext);
ASSERT(false);
return 0.0f;
}
return static_cast<float>(count) / static_cast<float>(std::max(contextCount, minCount)); return static_cast<float>(count) / static_cast<float>(std::max(contextCount, minCount));
} }
static float backoff(const int ngramProbability, const int matchedWordCountInContext) { static float backoff(const int ngramProbability, const NgramType ngramType) {
int probability = NOT_A_PROBABILITY; const int probability =
ngramProbability + ENCODED_BACKOFF_WEIGHTS[static_cast<int>(ngramType)];
switch (matchedWordCountInContext) {
case 1:
probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
break;
case 2:
probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
break;
case 3:
probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
break;
default:
AKLOGE("backoff is called with invalid matchedWordCountInContext (%d).",
matchedWordCountInContext);
ASSERT(false);
return NOT_A_PROBABILITY;
}
return std::min(std::max(probability, NOT_A_PROBABILITY), MAX_PROBABILITY); return std::min(std::max(probability, NOT_A_PROBABILITY), MAX_PROBABILITY);
} }
@ -99,14 +68,8 @@ private:
static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 2, "Max supported Ngram is Trigram."); static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 2, "Max supported Ngram is Trigram.");
static const int ASSUMED_MIN_COUNT_FOR_UNIGRAMS; static const int ASSUMED_MIN_COUNTS[];
static const int ASSUMED_MIN_COUNT_FOR_BIGRAMS; static const int ENCODED_BACKOFF_WEIGHTS[];
static const int ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
static const int ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
static const int ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
static const int ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
static const int DURATION_TO_DISCARD_ENTRY_IN_SECONDS; static const int DURATION_TO_DISCARD_ENTRY_IN_SECONDS;
}; };

View file

@ -21,6 +21,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h"
#include "utils/ngram_utils.h"
namespace latinime { namespace latinime {
@ -89,16 +90,17 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
} }
contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount(); contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
} }
const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(i + 1);
const float rawProbability = const float rawProbability =
DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts( DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
historicalInfo->getCount(), contextCount, i + 1); historicalInfo->getCount(), contextCount, ngramType);
const int encodedRawProbability = const int encodedRawProbability =
ProbabilityUtils::encodeRawProbability(rawProbability); ProbabilityUtils::encodeRawProbability(rawProbability);
const int decayedProbability = const int decayedProbability =
DynamicLanguageModelProbabilityUtils::getDecayedProbability( DynamicLanguageModelProbabilityUtils::getDecayedProbability(
encodedRawProbability, *historicalInfo); encodedRawProbability, *historicalInfo);
probability = DynamicLanguageModelProbabilityUtils::backoff( probability = DynamicLanguageModelProbabilityUtils::backoff(
decayedProbability, i + 1 /* n */); decayedProbability, ngramType);
} else { } else {
probability = probabilityEntry.getProbability(); probability = probabilityEntry.getProbability();
} }
@ -198,18 +200,19 @@ bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCo
MutableEntryCounters *const outEntryCounters) { MutableEntryCounters *const outEntryCounters) {
for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) { for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
const int totalWordCount = prevWordCount + 1; const int totalWordCount = prevWordCount + 1;
if (currentEntryCounts.getNgramCount(totalWordCount) const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(totalWordCount);
<= maxEntryCounts.getNgramCount(totalWordCount)) { if (currentEntryCounts.getNgramCount(ngramType)
outEntryCounters->setNgramCount(totalWordCount, <= maxEntryCounts.getNgramCount(ngramType)) {
currentEntryCounts.getNgramCount(totalWordCount)); outEntryCounters->setNgramCount(ngramType,
currentEntryCounts.getNgramCount(ngramType));
continue; continue;
} }
int entryCount = 0; int entryCount = 0;
if (!turncateEntriesInSpecifiedLevel(headerPolicy, if (!turncateEntriesInSpecifiedLevel(headerPolicy,
maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) { maxEntryCounts.getNgramCount(ngramType), prevWordCount, &entryCount)) {
return false; return false;
} }
outEntryCounters->setNgramCount(totalWordCount, entryCount); outEntryCounters->setNgramCount(ngramType, entryCount);
} }
return true; return true;
} }
@ -246,7 +249,10 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
mGlobalCounters.updateMaxValueOfCounters( mGlobalCounters.updateMaxValueOfCounters(
updatedNgramProbabilityEntry.getHistoricalInfo()->getCount()); updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
if (!originalNgramProbabilityEntry.isValid()) { if (!originalNgramProbabilityEntry.isValid()) {
entryCountersToUpdate->incrementNgramCount(i + 2); // (i + 2) words are used in total because the prevWords consists of (i + 1) words when
// looking at its i-th element.
entryCountersToUpdate->incrementNgramCount(
NgramUtils::getNgramTypeFromWordCount(i + 2));
} }
} }
return true; return true;
@ -369,7 +375,8 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
} }
} }
} }
outEntryCounters->incrementNgramCount(prevWordCount + 1); outEntryCounters->incrementNgramCount(
NgramUtils::getNgramTypeFromWordCount(prevWordCount + 1));
if (!entry.hasNextLevelMap()) { if (!entry.hasNextLevelMap()) {
continue; continue;
} }
@ -402,7 +409,8 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
for (int i = 0; i < entryCountToRemove; ++i) { for (int i = 0; i < entryCountToRemove; ++i) {
const EntryInfoToTurncate &entryInfo = entryInfoVector[i]; const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
if (!removeNgramProbabilityEntry( if (!removeNgramProbabilityEntry(
WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) { WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount),
entryInfo.mKey)) {
return false; return false;
} }
} }

View file

@ -31,6 +31,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h"
#include "utils/ngram_utils.h"
namespace latinime { namespace latinime {
@ -215,7 +216,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) { &addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
mEntryCounters.incrementUnigramCount(); mEntryCounters.incrementNgramCount(NgramType::Unigram);
} }
if (unigramProperty->getShortcuts().size() > 0) { if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target. // Add shortcut target.
@ -263,7 +264,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
return false; return false;
} }
if (!ptNodeParams.representsNonWordInfo()) { if (!ptNodeParams.representsNonWordInfo()) {
mEntryCounters.decrementUnigramCount(); mEntryCounters.decrementNgramCount(NgramType::Unigram);
} }
return true; return true;
} }
@ -321,7 +322,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope
bool addedNewEntry = false; bool addedNewEntry = false;
if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) { if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) {
if (addedNewEntry) { if (addedNewEntry) {
mEntryCounters.incrementNgramCount(prevWordIds.size() + 1); mEntryCounters.incrementNgramCount(
NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1));
} }
return true; return true;
} else { } else {
@ -359,7 +361,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
return false; return false;
} }
if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) { if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) {
mEntryCounters.decrementNgramCount(prevWordIds.size()); mEntryCounters.decrementNgramCount(
NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1));
return true; return true;
} else { } else {
return false; return false;
@ -477,20 +480,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) { char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */; const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) { if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount()); snprintf(outResult, maxResultLength, "%d",
mEntryCounters.getNgramCount(NgramType::Unigram));
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) { } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount()); snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram));
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) { } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ? mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit( ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxUnigramCount()) : mHeaderPolicy->getMaxNgramCounts().getNgramCount(
NgramType::Unigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) { } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ? mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit( ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxBigramCount()) : mHeaderPolicy->getMaxNgramCounts().getNgramCount(
NgramType::Bigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} }
} }

View file

@ -51,8 +51,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mShortcutPolicy), &mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()), mWritingHelper(mBuffers.get()),
mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(), mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()),
mHeaderPolicy->getTrigramCount()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}; mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
AK_FORCE_INLINE int getRootPosition() const { AK_FORCE_INLINE int getRootPosition() const {

View file

@ -29,6 +29,7 @@
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
#include "suggest/policyimpl/dictionary/utils/file_utils.h" #include "suggest/policyimpl/dictionary/utils/file_utils.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
#include "utils/ngram_utils.h"
namespace latinime { namespace latinime {
@ -43,8 +44,9 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat
entryCounts, extendedRegionSize, &headerBuffer)) { entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. " AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d," "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d,"
"extendedRegionSize: %d", false, entryCounts.getUnigramCount(), "extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram),
entryCounts.getBigramCount(), entryCounts.getTrigramCount(), entryCounts.getNgramCount(NgramType::Bigram),
entryCounts.getNgramCount(NgramType::Trigram),
extendedRegionSize); extendedRegionSize);
return false; return false;
} }
@ -86,8 +88,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return false; return false;
} }
if (headerPolicy->isDecayingDict()) { if (headerPolicy->isDecayingDict()) {
const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(), const EntryCounts &maxEntryCounts = headerPolicy->getMaxNgramCounts();
headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries( if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy, outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
outEntryCounters)) { outEntryCounters)) {

View file

@ -20,6 +20,7 @@
#include <array> #include <array>
#include "defines.h" #include "defines.h"
#include "utils/ngram_utils.h"
namespace latinime { namespace latinime {
@ -28,34 +29,22 @@ class EntryCounts final {
public: public:
EntryCounts() : mEntryCounts({{0, 0, 0}}) {} EntryCounts() : mEntryCounts({{0, 0, 0}}) {}
EntryCounts(const int unigramCount, const int bigramCount, const int trigramCount)
: mEntryCounts({{unigramCount, bigramCount, trigramCount}}) {}
explicit EntryCounts(const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters) explicit EntryCounts(const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
: mEntryCounts(counters) {} : mEntryCounts(counters) {}
int getUnigramCount() const { int getNgramCount(const NgramType ngramType) const {
return mEntryCounts[0]; return mEntryCounts[static_cast<int>(ngramType)];
} }
int getBigramCount() const { const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &getCountArray() const {
return mEntryCounts[1]; return mEntryCounts;
}
int getTrigramCount() const {
return mEntryCounts[2];
}
int getNgramCount(const size_t n) const {
if (n < 1 || n > mEntryCounts.size()) {
return 0;
}
return mEntryCounts[n - 1];
} }
private: private:
DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts); DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);
// Counts from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram
// (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element)
const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounts; const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounts;
}; };
@ -65,68 +54,35 @@ class MutableEntryCounters final {
mEntryCounters.fill(0); mEntryCounters.fill(0);
} }
MutableEntryCounters(const int unigramCount, const int bigramCount, const int trigramCount) explicit MutableEntryCounters(
: mEntryCounters({{unigramCount, bigramCount, trigramCount}}) {} const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
: mEntryCounters(counters) {}
const EntryCounts getEntryCounts() const { const EntryCounts getEntryCounts() const {
return EntryCounts(mEntryCounters); return EntryCounts(mEntryCounters);
} }
int getUnigramCount() const { void incrementNgramCount(const NgramType ngramType) {
return mEntryCounters[0]; ++mEntryCounters[static_cast<int>(ngramType)];
} }
int getBigramCount() const { void decrementNgramCount(const NgramType ngramType) {
return mEntryCounters[1]; --mEntryCounters[static_cast<int>(ngramType)];
} }
int getTrigramCount() const { int getNgramCount(const NgramType ngramType) const {
return mEntryCounters[2]; return mEntryCounters[static_cast<int>(ngramType)];
} }
void incrementUnigramCount() { void setNgramCount(const NgramType ngramType, const int count) {
++mEntryCounters[0]; mEntryCounters[static_cast<int>(ngramType)] = count;
}
void decrementUnigramCount() {
ASSERT(mEntryCounters[0] != 0);
--mEntryCounters[0];
}
void incrementBigramCount() {
++mEntryCounters[1];
}
void decrementBigramCount() {
ASSERT(mEntryCounters[1] != 0);
--mEntryCounters[1];
}
void incrementNgramCount(const size_t n) {
if (n < 1 || n > mEntryCounters.size()) {
return;
}
++mEntryCounters[n - 1];
}
void decrementNgramCount(const size_t n) {
if (n < 1 || n > mEntryCounters.size()) {
return;
}
ASSERT(mEntryCounters[n - 1] != 0);
--mEntryCounters[n - 1];
}
void setNgramCount(const size_t n, const int count) {
if (n < 1 || n > mEntryCounters.size()) {
return;
}
mEntryCounters[n - 1] = count;
} }
private: private:
DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters); DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);
// Counters from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram
// (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element)
std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounters; std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounters;
}; };
} // namespace latinime } // namespace latinime

View file

@ -126,20 +126,13 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
/* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay, /* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay,
const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) { const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) {
if (entryCounts.getUnigramCount() const EntryCounts &maxNgramCounts = headerPolicy->getMaxNgramCounts();
>= getEntryCountHardLimit(headerPolicy->getMaxUnigramCount())) { for (const auto ngramType : AllNgramTypes::ASCENDING) {
// Unigram count exceeds the limit. if (entryCounts.getNgramCount(ngramType)
return true; >= getEntryCountHardLimit(maxNgramCounts.getNgramCount(ngramType))) {
} // Unigram count exceeds the limit.
if (entryCounts.getBigramCount() return true;
>= getEntryCountHardLimit(headerPolicy->getMaxBigramCount())) { }
// Bigram count exceeds the limit.
return true;
}
if (entryCounts.getTrigramCount()
>= getEntryCountHardLimit(headerPolicy->getMaxTrigramCount())) {
// Trigram count exceeds the limit.
return true;
} }
if (mindsBlockByDecay) { if (mindsBlockByDecay) {
return false; return false;

View file

@ -0,0 +1,62 @@
/*
* Copyright (C) 2014, The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_NGRAM_UTILS_H
#define LATINIME_NGRAM_UTILS_H
#include "defines.h"
namespace latinime {
enum class NgramType : int {
Unigram = 0,
Bigram = 1,
Trigram = 2,
NotANgramType = -1,
};
namespace AllNgramTypes {
// Use anonymous namespace to avoid ODR (One Definition Rule) violation.
namespace {
const NgramType ASCENDING[] = {
NgramType::Unigram, NgramType::Bigram, NgramType::Trigram
};
const NgramType DESCENDING[] = {
NgramType::Trigram, NgramType::Bigram, NgramType::Unigram
};
} // namespace
} // namespace AllNgramTypes
class NgramUtils final {
public:
static AK_FORCE_INLINE NgramType getNgramTypeFromWordCount(const int wordCount) {
// Max supported ngram is (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram.
if (wordCount <= 0 || wordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1) {
return NgramType::NotANgramType;
}
// Convert word count to 0-origin enum value.
return static_cast<NgramType>(wordCount - 1);
}
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(NgramUtils);
};
}
#endif /* LATINIME_NGRAM_UTILS_H */