Use enum to specify ngram type.

Change-Id: Ie28768ceadcd7a2d940c57eb30be7d4c364e509f
This commit is contained in:
Keisuke Kuroyanagi 2014-11-20 15:27:30 +09:00
parent a94733cbca
commit 78212a6d3d
14 changed files with 218 additions and 251 deletions

View file

@ -18,6 +18,8 @@
#include <algorithm>
#include "utils/ngram_utils.h"
namespace latinime {
// Note that these are corresponding definitions in Java side in DictionaryHeader.
@ -28,9 +30,11 @@ const char *const HeaderPolicy::REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY =
const char *const HeaderPolicy::IS_DECAYING_DICT_KEY = "USES_FORGETTING_CURVE";
const char *const HeaderPolicy::DATE_KEY = "date";
const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME";
const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT";
const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT";
const char *const HeaderPolicy::TRIGRAM_COUNT_KEY = "TRIGRAM_COUNT";
// Per-ngram-type tables. Each table is indexed with the int obtained by casting an NgramType
// (see getIndexFromNgramType()), so the three tables must list their entries in the same order.
// NOTE(review): the entry order (unigram, bigram, trigram) presumably matches the enumerator
// values of NgramType - confirm against utils/ngram_utils.h.
// Attribute keys under which the current entry count of each ngram type is stored.
const char *const HeaderPolicy::NGRAM_COUNT_KEYS[] =
{"UNIGRAM_COUNT", "BIGRAM_COUNT", "TRIGRAM_COUNT"};
// Attribute keys under which the maximum entry count of each ngram type is stored.
const char *const HeaderPolicy::MAX_NGRAM_COUNT_KEYS[] =
{"MAX_UNIGRAM_ENTRY_COUNT", "MAX_BIGRAM_ENTRY_COUNT", "MAX_TRIGRAM_ENTRY_COUNT"};
// Default values passed when reading the maximum entry counts from the attribute map.
const int HeaderPolicy::DEFAULT_MAX_NGRAM_COUNTS[] = {10000, 30000, 30000};
const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE";
// Historical info is information that is needed to support decaying such as timestamp, level and
// count.
@ -39,18 +43,10 @@ const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
"FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT";
const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT";
const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT";
const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000;
const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000;
// Used for logging. Question mark is used to indicate that the key is not found.
void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue,
int outValueSize) const {
@ -126,15 +122,22 @@ bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTim
return true;
}
namespace {

// Converts an NgramType into the int used to index the per-ngram-type tables of HeaderPolicy
// (NGRAM_COUNT_KEYS, MAX_NGRAM_COUNT_KEYS and DEFAULT_MAX_NGRAM_COUNTS).
int getIndexFromNgramType(const NgramType ngramType) {
    const int tableIndex = static_cast<int>(ngramType);
    return tableIndex;
}

} // namespace
void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
const EntryCounts &entryCounts, const int extendedRegionSize,
DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const {
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY,
entryCounts.getUnigramCount());
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY,
entryCounts.getBigramCount());
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, TRIGRAM_COUNT_KEY,
entryCounts.getTrigramCount());
for (const auto ngramType : AllNgramTypes::ASCENDING) {
HeaderReadWriteUtils::setIntAttribute(outAttributeMap,
NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)],
entryCounts.getNgramCount(ngramType));
}
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY,
extendedRegionSize);
// Set the current time as the generation time.
@ -155,4 +158,25 @@ void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
return attributeMap;
}
// Builds the current entry counts of every ngram type from the header attributes.
// For each ngram type in AllNgramTypes::ASCENDING, the value stored under the corresponding
// NGRAM_COUNT_KEYS entry is read from mAttributeMap, with 0 passed as the default value.
// This is a non-static const member function because it reads mAttributeMap.
const EntryCounts HeaderPolicy::readNgramCounts() const {
    MutableEntryCounters entryCounters;
    for (const auto ngramType : AllNgramTypes::ASCENDING) {
        const char *const countKey = NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)];
        const int entryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
                countKey, 0 /* defaultValue */);
        entryCounters.setNgramCount(ngramType, entryCount);
    }
    return entryCounters.getEntryCounts();
}
// Builds the maximum entry counts of every ngram type from the header attributes.
// For each ngram type in AllNgramTypes::ASCENDING, the value stored under the corresponding
// MAX_NGRAM_COUNT_KEYS entry is read from mAttributeMap, with the matching
// DEFAULT_MAX_NGRAM_COUNTS entry passed as the default value.
// This is a non-static const member function because it reads mAttributeMap.
const EntryCounts HeaderPolicy::readMaxNgramCounts() const {
    MutableEntryCounters entryCounters;
    for (const auto ngramType : AllNgramTypes::ASCENDING) {
        // The key table and the default table share the same index for a given ngram type.
        const int index = getIndexFromNgramType(ngramType);
        const char *const maxCountKey = MAX_NGRAM_COUNT_KEYS[index];
        const int defaultMaxEntryCount = DEFAULT_MAX_NGRAM_COUNTS[index];
        const int maxEntryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
                maxCountKey, defaultMaxEntryCount);
        entryCounters.setNgramCount(ngramType, maxEntryCount);
    }
    return entryCounters.getEntryCounts();
}
} // namespace latinime

View file

@ -46,12 +46,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
LAST_DECAYED_TIME_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
UNIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
BIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
TRIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()),
mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
@ -59,12 +54,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Constructs header information using an attribute map.
@ -82,18 +71,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0),
mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()),
mExtendedRegionSize(0),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Copy header information
@ -105,15 +89,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mRequiresGermanUmlautProcessing(headerPolicy->mRequiresGermanUmlautProcessing),
mIsDecayingDict(headerPolicy->mIsDecayingDict),
mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime),
mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount),
mTrigramCount(headerPolicy->mTrigramCount),
mNgramCounts(headerPolicy->mNgramCounts),
mMaxNgramCounts(headerPolicy->mMaxNgramCounts),
mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
mForgettingCurveProbabilityValuesTableId(
headerPolicy->mForgettingCurveProbabilityValuesTableId),
mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
mMaxBigramCount(headerPolicy->mMaxBigramCount),
mMaxTrigramCount(headerPolicy->mMaxTrigramCount),
mCodePointTable(headerPolicy->mCodePointTable) {}
// Temporary dummy header.
@ -121,10 +102,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
: mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0),
mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f),
mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mTrigramCount(0),
mDate(0), mLastDecayedTime(0), mNgramCounts(), mMaxNgramCounts(),
mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0),
mMaxTrigramCount(0), mCodePointTable(nullptr) {}
mForgettingCurveProbabilityValuesTableId(0), mCodePointTable(nullptr) {}
~HeaderPolicy() {}
@ -186,16 +166,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mLastDecayedTime;
}
AK_FORCE_INLINE int getUnigramCount() const {
return mUnigramCount;
AK_FORCE_INLINE const EntryCounts &getNgramCounts() const {
return mNgramCounts;
}
AK_FORCE_INLINE int getBigramCount() const {
return mBigramCount;
}
AK_FORCE_INLINE int getTrigramCount() const {
return mTrigramCount;
// Returns the maximum entry counts of each ngram type held by this header policy.
// Returned by const reference, like getNgramCounts(), so that callers do not copy the
// EntryCounts object on every call. The reference stays valid as long as this HeaderPolicy
// instance is alive.
AK_FORCE_INLINE const EntryCounts &getMaxNgramCounts() const {
    return mMaxNgramCounts;
}
AK_FORCE_INLINE int getExtendedRegionSize() const {
@ -219,18 +195,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mForgettingCurveProbabilityValuesTableId;
}
AK_FORCE_INLINE int getMaxUnigramCount() const {
return mMaxUnigramCount;
}
AK_FORCE_INLINE int getMaxBigramCount() const {
return mMaxBigramCount;
}
AK_FORCE_INLINE int getMaxTrigramCount() const {
return mMaxTrigramCount;
}
void readHeaderValueOrQuestionMark(const char *const key,
int *outValue, int outValueSize) const;
@ -262,24 +226,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const IS_DECAYING_DICT_KEY;
static const char *const DATE_KEY;
static const char *const LAST_DECAYED_TIME_KEY;
static const char *const UNIGRAM_COUNT_KEY;
static const char *const BIGRAM_COUNT_KEY;
static const char *const TRIGRAM_COUNT_KEY;
static const char *const NGRAM_COUNT_KEYS[];
static const char *const MAX_NGRAM_COUNT_KEYS[];
static const int DEFAULT_MAX_NGRAM_COUNTS[];
static const char *const EXTENDED_REGION_SIZE_KEY;
static const char *const HAS_HISTORICAL_INFO_KEY;
static const char *const LOCALE_KEY;
static const char *const FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY;
static const char *const FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY;
static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY;
static const char *const MAX_UNIGRAM_COUNT_KEY;
static const char *const MAX_BIGRAM_COUNT_KEY;
static const char *const MAX_TRIGRAM_COUNT_KEY;
static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
static const int DEFAULT_MAX_UNIGRAM_COUNT;
static const int DEFAULT_MAX_BIGRAM_COUNT;
static const int DEFAULT_MAX_TRIGRAM_COUNT;
const FormatUtils::FORMAT_VERSION mDictFormatVersion;
const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags;
@ -291,21 +249,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const bool mIsDecayingDict;
const int mDate;
const int mLastDecayedTime;
const int mUnigramCount;
const int mBigramCount;
const int mTrigramCount;
const EntryCounts mNgramCounts;
const EntryCounts mMaxNgramCounts;
const int mExtendedRegionSize;
const bool mHasHistoricalInfoOfWords;
const int mForgettingCurveProbabilityValuesTableId;
const int mMaxUnigramCount;
const int mMaxBigramCount;
const int mMaxTrigramCount;
const int *const mCodePointTable;
const std::vector<int> readLocale() const;
float readMultipleWordCostMultiplier() const;
bool readRequiresGermanUmlautProcessing() const;
const EntryCounts readNgramCounts() const;
const EntryCounts readMaxNgramCounts() const;
static DictionaryHeaderStructurePolicy::AttributeMap createAttributeMapAndReadAllAttributes(
const uint8_t *const dictBuf);
};

View file

@ -303,7 +303,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
mEntryCounters.incrementUnigramCount();
mEntryCounters.incrementNgramCount(NgramType::Unigram);
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@ -397,7 +397,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, ngramProperty, &addedNewBigram)) {
if (addedNewBigram) {
mEntryCounters.incrementBigramCount();
mEntryCounters.incrementNgramCount(NgramType::Bigram);
}
return true;
} else {
@ -438,7 +438,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry(
PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
mEntryCounters.decrementBigramCount();
mEntryCounters.decrementNgramCount(NgramType::Bigram);
return true;
} else {
return false;
@ -525,20 +525,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
snprintf(outResult, maxResultLength, "%d",
mEntryCounters.getNgramCount(NgramType::Unigram));
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram));
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxUnigramCount()) :
mHeaderPolicy->getMaxNgramCounts().getNgramCount(
NgramType::Unigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxBigramCount()) :
mHeaderPolicy->getMaxNgramCounts().getNgramCount(
NgramType::Bigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
}

View file

@ -76,8 +76,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
mHeaderPolicy->getTrigramCount()),
mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
virtual int getRootPosition() const {

View file

@ -53,8 +53,8 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat
entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, "
"extendedRegionSize: %d", false, entryCounts.getUnigramCount(),
entryCounts.getBigramCount(), extendedRegionSize);
"extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram),
entryCounts.getNgramCount(NgramType::Bigram), extendedRegionSize);
return false;
}
return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@ -73,9 +73,11 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
}
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
MutableEntryCounters entryCounters;
entryCounters.setNgramCount(NgramType::Unigram, unigramCount);
entryCounters.setNgramCount(NgramType::Bigram, bigramCount);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */),
0 /* extendedRegionSize */, &headerBuffer)) {
entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
return false;
}
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@ -107,7 +109,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
}
const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted
.getValidUnigramCount();
const int maxUnigramCount = headerPolicy->getMaxUnigramCount();
const int maxUnigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Unigram);
if (headerPolicy->isDecayingDict() && unigramCount > maxUnigramCount) {
if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter, maxUnigramCount)) {
AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount,
@ -124,7 +126,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return false;
}
const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount();
const int maxBigramCount = headerPolicy->getMaxBigramCount();
const int maxBigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Bigram);
if (headerPolicy->isDecayingDict() && bigramCount > maxBigramCount) {
if (!truncateBigrams(maxBigramCount)) {
AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount, maxBigramCount);

View file

@ -18,17 +18,13 @@
namespace latinime {
// These counts are used to provide stable probabilities even if the user's input count is small.
const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_UNIGRAMS = 8192;
const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_BIGRAMS = 2;
const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_TRIGRAMS = 2;
// Used to provide stable probabilities even if the user's input count is small.
const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNTS[] = {8192, 2, 2};
// These are encoded backoff weights.
// Encoded backoff weights.
// Note that we give positive value for trigrams that means the weight is more than 1.
// TODO: Apply backoff for main dictionaries and quit giving a positive backoff weight.
const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS = -32;
const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS = 0;
const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS = 8;
const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHTS[] = {-32, 0, 8};
// This value is used to remove too old entries from the dictionary.
const int DynamicLanguageModelProbabilityUtils::DURATION_TO_DISCARD_ENTRY_IN_SECONDS =

View file

@ -21,6 +21,7 @@
#include "defines.h"
#include "suggest/core/dictionary/property/historical_info.h"
#include "utils/ngram_utils.h"
#include "utils/time_keeper.h"
namespace latinime {
@ -28,46 +29,14 @@ namespace latinime {
class DynamicLanguageModelProbabilityUtils {
public:
static float computeRawProbabilityFromCounts(const int count, const int contextCount,
const int matchedWordCountInContext) {
int minCount = 0;
switch (matchedWordCountInContext) {
case 1:
minCount = ASSUMED_MIN_COUNT_FOR_UNIGRAMS;
break;
case 2:
minCount = ASSUMED_MIN_COUNT_FOR_BIGRAMS;
break;
case 3:
minCount = ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
break;
default:
AKLOGE("computeRawProbabilityFromCounts is called with invalid "
"matchedWordCountInContext (%d).", matchedWordCountInContext);
ASSERT(false);
return 0.0f;
}
const NgramType ngramType) {
const int minCount = ASSUMED_MIN_COUNTS[static_cast<int>(ngramType)];
return static_cast<float>(count) / static_cast<float>(std::max(contextCount, minCount));
}
static float backoff(const int ngramProbability, const int matchedWordCountInContext) {
int probability = NOT_A_PROBABILITY;
switch (matchedWordCountInContext) {
case 1:
probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
break;
case 2:
probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
break;
case 3:
probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
break;
default:
AKLOGE("backoff is called with invalid matchedWordCountInContext (%d).",
matchedWordCountInContext);
ASSERT(false);
return NOT_A_PROBABILITY;
}
// Applies the encoded backoff weight of the given ngram type to ngramProbability.
// The weighted value is raised to at least NOT_A_PROBABILITY and then capped at
// MAX_PROBABILITY before being returned.
static float backoff(const int ngramProbability, const NgramType ngramType) {
    const int backoffWeight = ENCODED_BACKOFF_WEIGHTS[static_cast<int>(ngramType)];
    const int weightedProbability = ngramProbability + backoffWeight;
    const int lowerBounded = std::max(weightedProbability, NOT_A_PROBABILITY);
    return std::min(lowerBounded, MAX_PROBABILITY);
}
@ -99,14 +68,8 @@ private:
static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 2, "Max supported Ngram is Trigram.");
static const int ASSUMED_MIN_COUNT_FOR_UNIGRAMS;
static const int ASSUMED_MIN_COUNT_FOR_BIGRAMS;
static const int ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
static const int ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
static const int ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
static const int ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
static const int ASSUMED_MIN_COUNTS[];
static const int ENCODED_BACKOFF_WEIGHTS[];
static const int DURATION_TO_DISCARD_ENTRY_IN_SECONDS;
};

View file

@ -21,6 +21,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
#include "utils/ngram_utils.h"
namespace latinime {
@ -89,16 +90,17 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
}
contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
}
const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(i + 1);
const float rawProbability =
DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
historicalInfo->getCount(), contextCount, i + 1);
historicalInfo->getCount(), contextCount, ngramType);
const int encodedRawProbability =
ProbabilityUtils::encodeRawProbability(rawProbability);
const int decayedProbability =
DynamicLanguageModelProbabilityUtils::getDecayedProbability(
encodedRawProbability, *historicalInfo);
probability = DynamicLanguageModelProbabilityUtils::backoff(
decayedProbability, i + 1 /* n */);
decayedProbability, ngramType);
} else {
probability = probabilityEntry.getProbability();
}
@ -198,18 +200,19 @@ bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCo
MutableEntryCounters *const outEntryCounters) {
for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
const int totalWordCount = prevWordCount + 1;
if (currentEntryCounts.getNgramCount(totalWordCount)
<= maxEntryCounts.getNgramCount(totalWordCount)) {
outEntryCounters->setNgramCount(totalWordCount,
currentEntryCounts.getNgramCount(totalWordCount));
const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(totalWordCount);
if (currentEntryCounts.getNgramCount(ngramType)
<= maxEntryCounts.getNgramCount(ngramType)) {
outEntryCounters->setNgramCount(ngramType,
currentEntryCounts.getNgramCount(ngramType));
continue;
}
int entryCount = 0;
if (!turncateEntriesInSpecifiedLevel(headerPolicy,
maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
maxEntryCounts.getNgramCount(ngramType), prevWordCount, &entryCount)) {
return false;
}
outEntryCounters->setNgramCount(totalWordCount, entryCount);
outEntryCounters->setNgramCount(ngramType, entryCount);
}
return true;
}
@ -246,7 +249,10 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
mGlobalCounters.updateMaxValueOfCounters(
updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
if (!originalNgramProbabilityEntry.isValid()) {
entryCountersToUpdate->incrementNgramCount(i + 2);
// (i + 2) words are used in total because the prevWords consists of (i + 1) words when
// looking at its i-th element.
entryCountersToUpdate->incrementNgramCount(
NgramUtils::getNgramTypeFromWordCount(i + 2));
}
}
return true;
@ -369,7 +375,8 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
}
}
}
outEntryCounters->incrementNgramCount(prevWordCount + 1);
outEntryCounters->incrementNgramCount(
NgramUtils::getNgramTypeFromWordCount(prevWordCount + 1));
if (!entry.hasNextLevelMap()) {
continue;
}
@ -402,7 +409,8 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
for (int i = 0; i < entryCountToRemove; ++i) {
const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
if (!removeNgramProbabilityEntry(
WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) {
WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount),
entryInfo.mKey)) {
return false;
}
}

View file

@ -31,6 +31,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
#include "utils/ngram_utils.h"
namespace latinime {
@ -215,7 +216,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
mEntryCounters.incrementUnigramCount();
mEntryCounters.incrementNgramCount(NgramType::Unigram);
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@ -263,7 +264,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
return false;
}
if (!ptNodeParams.representsNonWordInfo()) {
mEntryCounters.decrementUnigramCount();
mEntryCounters.decrementNgramCount(NgramType::Unigram);
}
return true;
}
@ -321,7 +322,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope
bool addedNewEntry = false;
if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) {
if (addedNewEntry) {
mEntryCounters.incrementNgramCount(prevWordIds.size() + 1);
mEntryCounters.incrementNgramCount(
NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1));
}
return true;
} else {
@ -359,7 +361,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
return false;
}
if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) {
mEntryCounters.decrementNgramCount(prevWordIds.size());
mEntryCounters.decrementNgramCount(
NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1));
return true;
} else {
return false;
@ -477,20 +480,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
snprintf(outResult, maxResultLength, "%d",
mEntryCounters.getNgramCount(NgramType::Unigram));
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram));
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxUnigramCount()) :
mHeaderPolicy->getMaxNgramCounts().getNgramCount(
NgramType::Unigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxBigramCount()) :
mHeaderPolicy->getMaxNgramCounts().getNgramCount(
NgramType::Bigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
}

View file

@ -51,8 +51,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
mHeaderPolicy->getTrigramCount()),
mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
AK_FORCE_INLINE int getRootPosition() const {

View file

@ -29,6 +29,7 @@
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
#include "suggest/policyimpl/dictionary/utils/file_utils.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
#include "utils/ngram_utils.h"
namespace latinime {
@ -43,8 +44,9 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat
entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d,"
"extendedRegionSize: %d", false, entryCounts.getUnigramCount(),
entryCounts.getBigramCount(), entryCounts.getTrigramCount(),
"extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram),
entryCounts.getNgramCount(NgramType::Bigram),
entryCounts.getNgramCount(NgramType::Trigram),
extendedRegionSize);
return false;
}
@ -86,8 +88,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return false;
}
if (headerPolicy->isDecayingDict()) {
const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(),
headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
const EntryCounts &maxEntryCounts = headerPolicy->getMaxNgramCounts();
if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
outEntryCounters)) {

View file

@ -20,6 +20,7 @@
#include <array>
#include "defines.h"
#include "utils/ngram_utils.h"
namespace latinime {
@ -28,34 +29,22 @@ class EntryCounts final {
public:
EntryCounts() : mEntryCounts({{0, 0, 0}}) {}
EntryCounts(const int unigramCount, const int bigramCount, const int trigramCount)
: mEntryCounts({{unigramCount, bigramCount, trigramCount}}) {}
explicit EntryCounts(const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
: mEntryCounts(counters) {}
int getUnigramCount() const {
return mEntryCounts[0];
int getNgramCount(const NgramType ngramType) const {
return mEntryCounts[static_cast<int>(ngramType)];
}
int getBigramCount() const {
return mEntryCounts[1];
}
int getTrigramCount() const {
return mEntryCounts[2];
}
int getNgramCount(const size_t n) const {
if (n < 1 || n > mEntryCounts.size()) {
return 0;
}
return mEntryCounts[n - 1];
const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &getCountArray() const {
return mEntryCounts;
}
private:
DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);
// Counts from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram
// (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element)
const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounts;
};
@ -65,68 +54,35 @@ class MutableEntryCounters final {
mEntryCounters.fill(0);
}
MutableEntryCounters(const int unigramCount, const int bigramCount, const int trigramCount)
: mEntryCounters({{unigramCount, bigramCount, trigramCount}}) {}
explicit MutableEntryCounters(
const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
: mEntryCounters(counters) {}
const EntryCounts getEntryCounts() const {
return EntryCounts(mEntryCounters);
}
int getUnigramCount() const {
return mEntryCounters[0];
void incrementNgramCount(const NgramType ngramType) {
++mEntryCounters[static_cast<int>(ngramType)];
}
int getBigramCount() const {
return mEntryCounters[1];
void decrementNgramCount(const NgramType ngramType) {
--mEntryCounters[static_cast<int>(ngramType)];
}
int getTrigramCount() const {
return mEntryCounters[2];
int getNgramCount(const NgramType ngramType) const {
return mEntryCounters[static_cast<int>(ngramType)];
}
void incrementUnigramCount() {
++mEntryCounters[0];
}
void decrementUnigramCount() {
ASSERT(mEntryCounters[0] != 0);
--mEntryCounters[0];
}
void incrementBigramCount() {
++mEntryCounters[1];
}
void decrementBigramCount() {
ASSERT(mEntryCounters[1] != 0);
--mEntryCounters[1];
}
void incrementNgramCount(const size_t n) {
if (n < 1 || n > mEntryCounters.size()) {
return;
}
++mEntryCounters[n - 1];
}
void decrementNgramCount(const size_t n) {
if (n < 1 || n > mEntryCounters.size()) {
return;
}
ASSERT(mEntryCounters[n - 1] != 0);
--mEntryCounters[n - 1];
}
void setNgramCount(const size_t n, const int count) {
if (n < 1 || n > mEntryCounters.size()) {
return;
}
mEntryCounters[n - 1] = count;
void setNgramCount(const NgramType ngramType, const int count) {
mEntryCounters[static_cast<int>(ngramType)] = count;
}
private:
DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);
// Counters from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram
// (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element)
std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounters;
};
} // namespace latinime

View file

@ -126,20 +126,13 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
/* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay,
const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) {
if (entryCounts.getUnigramCount()
>= getEntryCountHardLimit(headerPolicy->getMaxUnigramCount())) {
const EntryCounts &maxNgramCounts = headerPolicy->getMaxNgramCounts();
for (const auto ngramType : AllNgramTypes::ASCENDING) {
if (entryCounts.getNgramCount(ngramType)
>= getEntryCountHardLimit(maxNgramCounts.getNgramCount(ngramType))) {
// The entry count of this ngram type has reached its hard limit.
return true;
}
if (entryCounts.getBigramCount()
>= getEntryCountHardLimit(headerPolicy->getMaxBigramCount())) {
// Bigram count exceeds the limit.
return true;
}
if (entryCounts.getTrigramCount()
>= getEntryCountHardLimit(headerPolicy->getMaxTrigramCount())) {
// Trigram count exceeds the limit.
return true;
}
if (mindsBlockByDecay) {
return false;

View file

@ -0,0 +1,62 @@
/*
* Copyright (C) 2014, The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_NGRAM_UTILS_H
#define LATINIME_NGRAM_UTILS_H
#include "defines.h"
namespace latinime {
// Identifies the order of an n-gram. The numeric value of each valid type is its word count
// minus one (unigram = 0, bigram = 1, trigram = 2), which lets the value be used as a
// 0-origin index after a static_cast<int>.
enum class NgramType : int {
    // Sentinel meaning "not a supported n-gram type". It is negative, so it must never be
    // used as an index.
    NotANgramType = -1,
    Unigram = 0,
    Bigram = 1,
    Trigram = 2
};
// Fixed iteration orders over every valid NgramType, intended for range-based for loops such as
// `for (const auto ngramType : AllNgramTypes::ASCENDING)`.
namespace AllNgramTypes {
// These arrays are constexpr, hence const, and const objects at namespace scope have internal
// linkage. Every translation unit including this header therefore gets its own private copy,
// so no unnamed namespace is needed (and unnamed namespaces should be avoided in headers).

// All valid types from the lowest order (unigram) to the highest (trigram).
constexpr NgramType ASCENDING[] = {
    NgramType::Unigram, NgramType::Bigram, NgramType::Trigram
};
// All valid types from the highest order (trigram) to the lowest (unigram).
constexpr NgramType DESCENDING[] = {
    NgramType::Trigram, NgramType::Bigram, NgramType::Unigram
};
} // namespace AllNgramTypes
// Static helpers for working with NgramType. This class is not meant to be instantiated.
class NgramUtils final {
 public:
    // Maps the number of words in an n-gram to the matching NgramType.
    //
    // wordCount: number of words in the n-gram (1 for a unigram, 2 for a bigram, ...).
    // Returns NgramType::NotANgramType when wordCount is not in the range
    // [1, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; otherwise the type whose enum value is
    // wordCount - 1.
    static AK_FORCE_INLINE NgramType getNgramTypeFromWordCount(const int wordCount) {
        // The largest supported n-gram is the (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram.
        const bool isSupportedWordCount =
                wordCount > 0 && wordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1;
        // Enum values are 0-origin while word counts are 1-origin, hence the "- 1".
        return isSupportedWordCount
                ? static_cast<NgramType>(wordCount - 1) : NgramType::NotANgramType;
    }

 private:
    DISALLOW_IMPLICIT_CONSTRUCTORS(NgramUtils);
};
}
#endif /* LATINIME_NGRAM_UTILS_H */