Merge "Enable count based dynamic ngram language model for v403."

This commit is contained in:
Keisuke Kuroyanagi 2014-10-31 03:15:18 +00:00 committed by Android (Google) Code Review
commit c096100b01
7 changed files with 168 additions and 128 deletions

View file

@ -81,6 +81,9 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
} }
const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext( const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
mPrevWordIds, targetWordId, nullptr /* multiBigramMap */); mPrevWordIds, targetWordId, nullptr /* multiBigramMap */);
if (wordAttributes.getProbability() == NOT_A_PROBABILITY) {
return;
}
mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
wordAttributes.getProbability()); wordAttributes.getProbability());
} }

View file

@ -26,6 +26,8 @@ namespace latinime {
*/ */
class NgramListener { class NgramListener {
public: public:
// ngramProbability is always 0 for v403 decaying dictionary.
// TODO: Remove ngramProbability.
virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0; virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0;
virtual ~NgramListener() {}; virtual ~NgramListener() {};

View file

@ -19,11 +19,11 @@
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
namespace latinime { namespace latinime {
const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0; const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0;
const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1; const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1;
@ -39,7 +39,8 @@ bool LanguageModelDictContent::runGC(
} }
const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds, const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
const int wordId, const HeaderPolicy *const headerPolicy) const { const int wordId, const bool mustMatchAllPrevWords,
const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
int maxPrevWordCount = 0; int maxPrevWordCount = 0;
@ -53,7 +54,15 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
} }
const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) {
// The word should be treated as an invalid word.
return WordAttributes();
}
for (int i = maxPrevWordCount; i >= 0; --i) { for (int i = maxPrevWordCount; i >= 0; --i) {
if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) {
break;
}
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
if (!result.mIsValid) { if (!result.mIsValid) {
continue; continue;
@ -62,36 +71,39 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
int probability = NOT_A_PROBABILITY; int probability = NOT_A_PROBABILITY;
if (mHasHistoricalInfo) { if (mHasHistoricalInfo) {
const int rawProbability = ForgettingCurveUtils::decodeProbability( const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
probabilityEntry.getHistoricalInfo(), headerPolicy); int contextCount = 0;
if (rawProbability == NOT_A_PROBABILITY) {
// The entry should not be treated as a valid entry.
continue;
}
if (i == 0) { if (i == 0) {
// unigram // unigram
probability = rawProbability; contextCount = mGlobalCounters.getTotalCount();
} else { } else {
const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry( const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]); prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
if (!prevWordProbabilityEntry.isValid()) { if (!prevWordProbabilityEntry.isValid()) {
continue; continue;
} }
if (prevWordProbabilityEntry.representsBeginningOfSentence()) { if (prevWordProbabilityEntry.representsBeginningOfSentence()
probability = rawProbability; && historicalInfo->getCount() == 1) {
} else { // BoS ngram requires multiple contextCount.
const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability( continue;
prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
probability = std::min(MAX_PROBABILITY - prevWordRawProbability
+ rawProbability, MAX_PROBABILITY);
} }
contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
} }
const float rawProbability =
DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
historicalInfo->getCount(), contextCount, i + 1);
const int encodedRawProbability =
ProbabilityUtils::encodeRawProbability(rawProbability);
const int decayedProbability =
DynamicLanguageModelProbabilityUtils::getDecayedProbability(
encodedRawProbability, *historicalInfo);
probability = DynamicLanguageModelProbabilityUtils::backoff(
decayedProbability, i + 1 /* n */);
} else { } else {
probability = probabilityEntry.getProbability(); probability = probabilityEntry.getProbability();
} }
// TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
// probabilityEntry. // probabilityEntry.
const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(), return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(),
unigramProbabilityEntry.isNotAWord(), unigramProbabilityEntry.isNotAWord(),
unigramProbabilityEntry.isPossiblyOffensive()); unigramProbabilityEntry.isPossiblyOffensive());
@ -167,7 +179,8 @@ void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
if (probabilityEntry.isValid()) { if (probabilityEntry.isValid()) {
const WordAttributes wordAttributes = getWordAttributes( const WordAttributes wordAttributes = getWordAttributes(
WordIdArrayView(*prevWordIds), wordId, headerPolicy); WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */,
headerPolicy);
outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId, outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
wordAttributes, probabilityEntry); wordAttributes, probabilityEntry);
} }
@ -231,7 +244,7 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
return false; return false;
} }
mGlobalCounters.updateMaxValueOfCounters( mGlobalCounters.updateMaxValueOfCounters(
updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount()); updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
if (!originalNgramProbabilityEntry.isValid()) { if (!originalNgramProbabilityEntry.isValid()) {
entryCountersToUpdate->incrementNgramCount(i + 2); entryCountersToUpdate->incrementNgramCount(i + 2);
} }
@ -242,10 +255,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom( const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
const ProbabilityEntry &originalProbabilityEntry, const bool isValid, const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const { const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(),
originalProbabilityEntry.getHistoricalInfo(), isValid ? 0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount()
DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY, + historicalInfo.getCount());
&historicalInfo, headerPolicy);
if (originalProbabilityEntry.isValid()) { if (originalProbabilityEntry.isValid()) {
return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo); return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo);
} else { } else {
@ -311,7 +323,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
const int prevWordCount, const HeaderPolicy *const headerPolicy, const int prevWordCount, const HeaderPolicy *const headerPolicy,
MutableEntryCounters *const outEntryCounters) { const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@ -328,33 +340,41 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
} }
continue; continue;
} }
if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence() if (mHasHistoricalInfo && probabilityEntry.isValid()) {
&& probabilityEntry.isValid()) { const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo();
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC(
probabilityEntry.getHistoricalInfo(), headerPolicy); *originalHistoricalInfo)) {
if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
// Update the entry.
const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
bitmapEntryIndex)) {
return false;
}
} else {
// Remove the entry. // Remove the entry.
if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
return false; return false;
} }
continue; continue;
} }
if (needsToHalveCounters) {
const int updatedCount = originalHistoricalInfo->getCount() / 2;
if (updatedCount == 0) {
// Remove the entry.
if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
return false;
}
continue;
}
const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(),
originalHistoricalInfo->getLevel(), updatedCount);
const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(),
&historicalInfoToSave);
if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
bitmapEntryIndex)) {
return false;
}
}
} }
if (!probabilityEntry.representsBeginningOfSentence()) { outEntryCounters->incrementNgramCount(prevWordCount + 1);
outEntryCounters->incrementNgramCount(prevWordCount + 1);
}
if (!entry.hasNextLevelMap()) { if (!entry.hasNextLevelMap()) {
continue; continue;
} }
if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
prevWordCount + 1, headerPolicy, outEntryCounters)) { prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) {
return false; return false;
} }
} }
@ -408,11 +428,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli
} }
const ProbabilityEntry probabilityEntry = const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
const int probability = (mHasHistoricalInfo) ? const int priority = mHasHistoricalInfo
ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction(
headerPolicy) : probabilityEntry.getProbability(); *probabilityEntry.getHistoricalInfo())
outEntryInfo->emplace_back(probability, : probabilityEntry.getProbability();
probabilityEntry.getHistoricalInfo()->getTimestamp(), outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(),
entry.key(), targetLevel, prevWordIds->data()); entry.key(), targetLevel, prevWordIds->data());
} }
return true; return true;
@ -420,11 +440,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli
bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const { const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
if (left.mProbability != right.mProbability) { if (left.mPriority != right.mPriority) {
return left.mProbability < right.mProbability; return left.mPriority < right.mPriority;
} }
if (left.mTimestamp != right.mTimestamp) { if (left.mCount != right.mCount) {
return left.mTimestamp > right.mTimestamp; return left.mCount < right.mCount;
} }
if (left.mKey != right.mKey) { if (left.mKey != right.mKey) {
return left.mKey < right.mKey; return left.mKey < right.mKey;
@ -441,10 +461,9 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
return false; return false;
} }
LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority,
const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds) const int count, const int key, const int prevWordCount, const int *const prevWordIds)
: mProbability(probability), mTimestamp(timestamp), mKey(key), : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) {
mPrevWordCount(prevWordCount) {
memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0])); memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
} }

View file

@ -151,13 +151,14 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent); const LanguageModelDictContent *const originalContent);
const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
const HeaderPolicy *const headerPolicy) const; const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const;
ProbabilityEntry getProbabilityEntry(const int wordId) const { ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId); return getNgramProbabilityEntry(WordIdArrayView(), wordId);
} }
bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) { bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) {
mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount());
return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry);
} }
@ -180,8 +181,15 @@ class LanguageModelDictContent {
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
MutableEntryCounters *const outEntryCounters) { MutableEntryCounters *const outEntryCounters) {
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
0 /* prevWordCount */, headerPolicy, outEntryCounters); 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(),
outEntryCounters)) {
return false;
}
if (mGlobalCounters.needsToHalveCounters()) {
mGlobalCounters.halveCounters();
}
return true;
} }
// entryCounts should be created by updateAllProbabilityEntries. // entryCounts should be created by updateAllProbabilityEntries.
@ -206,11 +214,12 @@ class LanguageModelDictContent {
DISALLOW_ASSIGNMENT_OPERATOR(Comparator); DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
}; };
EntryInfoToTurncate(const int probability, const int timestamp, const int key, EntryInfoToTurncate(const int priority, const int count, const int key,
const int prevWordCount, const int *const prevWordIds); const int prevWordCount, const int *const prevWordIds);
int mProbability; int mPriority;
int mTimestamp; // TODO: Remove.
int mCount;
int mKey; int mKey;
int mPrevWordCount; int mPrevWordCount;
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
@ -219,8 +228,6 @@ class LanguageModelDictContent {
DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
}; };
// TODO: Remove
static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;
static const int TRIE_MAP_BUFFER_INDEX; static const int TRIE_MAP_BUFFER_INDEX;
static const int GLOBAL_COUNTERS_BUFFER_INDEX; static const int GLOBAL_COUNTERS_BUFFER_INDEX;
@ -233,7 +240,8 @@ class LanguageModelDictContent {
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters); const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters,
MutableEntryCounters *const outEntryCounters);
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
const int maxEntryCount, const int targetLevel, int *const outEntryCount); const int maxEntryCount, const int targetLevel, int *const outEntryCount);
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,

View file

@ -63,6 +63,10 @@ class LanguageModelDictContentGlobalCounters {
mTotalCount += 1; mTotalCount += 1;
} }
void addToTotalCount(const int count) {
mTotalCount += count;
}
void updateMaxValueOfCounters(const int count) { void updateMaxValueOfCounters(const int count) {
mMaxValueOfCounters = std::max(count, mMaxValueOfCounters); mMaxValueOfCounters = std::max(count, mMaxValueOfCounters);
} }

View file

@ -110,7 +110,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
return WordAttributes(); return WordAttributes();
} }
return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId, return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId,
mHeaderPolicy); false /* mustMatchAllPrevWords */, mHeaderPolicy);
} }
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@ -118,18 +118,13 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) { if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
const ProbabilityEntry probabilityEntry = const WordAttributes wordAttributes =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId); mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId,
if (!probabilityEntry.isValid() || probabilityEntry.isBlacklisted() true /* mustMatchAllPrevWords */, mHeaderPolicy);
|| probabilityEntry.isNotAWord()) { if (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} }
if (mHeaderPolicy->hasHistoricalInfoOfWords()) { return wordAttributes.getProbability();
return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
mHeaderPolicy);
} else {
return probabilityEntry.getProbability();
}
} }
BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
@ -152,9 +147,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
continue; continue;
} }
const int probability = probabilityEntry.hasHistoricalInfo() ? const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability( 0 : probabilityEntry.getProbability();
probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
probabilityEntry.getProbability();
listener->onVisitEntry(probability, entry.getWordId()); listener->onVisitEntry(probability, entry.getWordId());
} }
} }
@ -386,25 +379,35 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext()."); AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext().");
return false; return false;
} }
if (!isValidWord) {
return true;
}
wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */); wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
} }
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
false /* tryLowerCaseSearch */); false /* tryLowerCaseSearch */);
if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID if (ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
&& ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) { if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) {
const UnigramProperty beginningOfSentenceUnigramProperty( const UnigramProperty beginningOfSentenceUnigramProperty(
true /* representsBeginningOfSentence */, true /* representsBeginningOfSentence */,
true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY, true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY,
HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */)); HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
&beginningOfSentenceUnigramProperty)) { &beginningOfSentenceUnigramProperty)) {
AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext()."); AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext().");
return false;
}
// Refresh word ids.
ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
}
// Update entries for beginning of sentence.
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(
prevWordIds.skip(1 /* n */), prevWordIds[0], true /* isVaild */, historicalInfo,
mHeaderPolicy, &mEntryCounters)) {
return false; return false;
} }
// Refresh word ids.
ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
} }
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds, if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds,
wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &mEntryCounters)) { wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &mEntryCounters)) {
@ -542,7 +545,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
} }
} }
const WordAttributes wordAttributes = languageModelDictContent->getWordAttributes( const WordAttributes wordAttributes = languageModelDictContent->getWordAttributes(
WordIdArrayView(), wordId, mHeaderPolicy); WordIdArrayView(), wordId, true /* mustMatchAllPrevWords */, mHeaderPolicy);
const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(wordId); const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(wordId);
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(), const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),

View file

@ -59,6 +59,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
super.setUp(); super.setUp();
mCurrentTime = 0; mCurrentTime = 0;
mDictFilesToBeDeleted.clear(); mDictFilesToBeDeleted.clear();
setCurrentTimeForTestMode(mCurrentTime);
} }
@Override @Override
@ -71,8 +72,8 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
super.tearDown(); super.tearDown();
} }
private static boolean supportsBeginningOfSentence(final int formatVersion) { private static boolean supportsCountBasedNgram(final int formatVersion) {
return formatVersion > FormatSpec.VERSION401; return formatVersion >= FormatSpec.VERSION4_DEV;
} }
private static boolean supportsNgram(final int formatVersion) { private static boolean supportsNgram(final int formatVersion) {
@ -142,19 +143,13 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
private File createEmptyDictionaryWithAttributeMapAndGetFile(final int formatVersion, private File createEmptyDictionaryWithAttributeMapAndGetFile(final int formatVersion,
final HashMap<String, String> attributeMap) { final HashMap<String, String> attributeMap) {
if (formatVersion == FormatSpec.VERSION4 try {
|| formatVersion == FormatSpec.VERSION4_ONLY_FOR_TESTING final File dictFile = createEmptyVer4DictionaryAndGetFile(formatVersion,
|| formatVersion == FormatSpec.VERSION4_DEV) { attributeMap);
try { mDictFilesToBeDeleted.add(dictFile);
final File dictFile = createEmptyVer4DictionaryAndGetFile(formatVersion, return dictFile;
attributeMap); } catch (final IOException e) {
mDictFilesToBeDeleted.add(dictFile); fail(e.toString());
return dictFile;
} catch (final IOException e) {
fail(e.toString());
}
} else {
fail("Dictionary format version " + formatVersion + " is not supported.");
} }
return null; return null;
} }
@ -263,12 +258,10 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
onInputWord(binaryDictionary, "a", false /* isValidWord */); onInputWord(binaryDictionary, "a", false /* isValidWord */);
assertTrue(binaryDictionary.isValidWord("a")); assertTrue(binaryDictionary.isValidWord("a"));
onInputWord(binaryDictionary, "b", true /* isValidWord */);
assertTrue(binaryDictionary.isValidWord("b"));
onInputWordWithPrevWord(binaryDictionary, "b", false /* isValidWord */, "a"); onInputWordWithPrevWord(binaryDictionary, "b", false /* isValidWord */, "a");
assertFalse(isValidBigram(binaryDictionary, "a", "b")); assertFalse(isValidBigram(binaryDictionary, "a", "b"));
onInputWordWithPrevWord(binaryDictionary, "b", false /* isValidWord */, "a"); onInputWordWithPrevWord(binaryDictionary, "b", false /* isValidWord */, "a");
assertTrue(binaryDictionary.isValidWord("b"));
assertTrue(isValidBigram(binaryDictionary, "a", "b")); assertTrue(isValidBigram(binaryDictionary, "a", "b"));
onInputWordWithPrevWord(binaryDictionary, "c", true /* isValidWord */, "a"); onInputWordWithPrevWord(binaryDictionary, "c", true /* isValidWord */, "a");
@ -284,16 +277,12 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
return; return;
} }
onInputWordWithPrevWords(binaryDictionary, "c", false /* isValidWord */, "b", "a"); onInputWordWithPrevWords(binaryDictionary, "c", true /* isValidWord */, "b", "a");
assertFalse(isValidTrigram(binaryDictionary, "a", "b", "c"));
assertFalse(isValidBigram(binaryDictionary, "b", "c"));
onInputWordWithPrevWords(binaryDictionary, "c", false /* isValidWord */, "b", "a");
assertTrue(isValidTrigram(binaryDictionary, "a", "b", "c")); assertTrue(isValidTrigram(binaryDictionary, "a", "b", "c"));
assertTrue(isValidBigram(binaryDictionary, "b", "c")); assertTrue(isValidBigram(binaryDictionary, "b", "c"));
onInputWordWithPrevWords(binaryDictionary, "d", false /* isValidWord */, "c", "b");
onInputWordWithPrevWords(binaryDictionary, "d", true /* isValidWord */, "b", "a"); assertFalse(isValidTrigram(binaryDictionary, "b", "c", "d"));
assertTrue(isValidTrigram(binaryDictionary, "a", "b", "d")); assertFalse(isValidBigram(binaryDictionary, "c", "d"));
assertTrue(isValidBigram(binaryDictionary, "b", "d"));
onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "b", "a"); onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "b", "a");
assertTrue(isValidTrigram(binaryDictionary, "a", "b", "cd")); assertTrue(isValidTrigram(binaryDictionary, "a", "b", "cd"));
@ -312,6 +301,13 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
onInputWord(binaryDictionary, "a", true /* isValidWord */); onInputWord(binaryDictionary, "a", true /* isValidWord */);
assertTrue(binaryDictionary.isValidWord("a")); assertTrue(binaryDictionary.isValidWord("a"));
forcePassingShortTime(binaryDictionary); forcePassingShortTime(binaryDictionary);
if (supportsCountBasedNgram(formatVersion)) {
// Count based ngram language model doesn't support decaying based on the elapsed time.
assertTrue(binaryDictionary.isValidWord("a"));
} else {
assertFalse(binaryDictionary.isValidWord("a"));
}
forcePassingLongTime(binaryDictionary);
assertFalse(binaryDictionary.isValidWord("a")); assertFalse(binaryDictionary.isValidWord("a"));
onInputWord(binaryDictionary, "a", true /* isValidWord */); onInputWord(binaryDictionary, "a", true /* isValidWord */);
@ -327,6 +323,12 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
onInputWordWithPrevWord(binaryDictionary, "b", true /* isValidWord */, "a"); onInputWordWithPrevWord(binaryDictionary, "b", true /* isValidWord */, "a");
assertTrue(isValidBigram(binaryDictionary, "a", "b")); assertTrue(isValidBigram(binaryDictionary, "a", "b"));
forcePassingShortTime(binaryDictionary); forcePassingShortTime(binaryDictionary);
if (supportsCountBasedNgram(formatVersion)) {
assertTrue(isValidBigram(binaryDictionary, "a", "b"));
} else {
assertFalse(isValidBigram(binaryDictionary, "a", "b"));
}
forcePassingLongTime(binaryDictionary);
assertFalse(isValidBigram(binaryDictionary, "a", "b")); assertFalse(isValidBigram(binaryDictionary, "a", "b"));
onInputWord(binaryDictionary, "a", true /* isValidWord */); onInputWord(binaryDictionary, "a", true /* isValidWord */);
@ -349,7 +351,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab"); onInputWordWithPrevWord(binaryDictionary, "bc", true /* isValidWord */, "ab");
onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab"); onInputWordWithPrevWords(binaryDictionary, "cd", true /* isValidWord */, "bc", "ab");
assertTrue(isValidTrigram(binaryDictionary, "ab", "bc", "cd")); assertTrue(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
forcePassingShortTime(binaryDictionary); forcePassingLongTime(binaryDictionary);
assertFalse(isValidTrigram(binaryDictionary, "ab", "bc", "cd")); assertFalse(isValidTrigram(binaryDictionary, "ab", "bc", "cd"));
onInputWord(binaryDictionary, "ab", true /* isValidWord */); onInputWord(binaryDictionary, "ab", true /* isValidWord */);
@ -540,7 +542,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
assertTrue(bigramCountBeforeGC > bigramCountAfterGC); assertTrue(bigramCountBeforeGC > bigramCountAfterGC);
} }
} }
forcePassingShortTime(binaryDictionary);
assertTrue(Integer.parseInt(binaryDictionary.getPropertyForGettingStats( assertTrue(Integer.parseInt(binaryDictionary.getPropertyForGettingStats(
BinaryDictionary.BIGRAM_COUNT_QUERY)) > 0); BinaryDictionary.BIGRAM_COUNT_QUERY)) > 0);
assertTrue(Integer.parseInt(binaryDictionary.getPropertyForGettingStats( assertTrue(Integer.parseInt(binaryDictionary.getPropertyForGettingStats(
@ -666,14 +668,17 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
assertEquals(toFormatVersion, binaryDictionary.getFormatVersion()); assertEquals(toFormatVersion, binaryDictionary.getFormatVersion());
assertTrue(binaryDictionary.isValidWord("aaa")); assertTrue(binaryDictionary.isValidWord("aaa"));
assertFalse(binaryDictionary.isValidWord("bbb")); assertFalse(binaryDictionary.isValidWord("bbb"));
assertTrue(binaryDictionary.getFrequency("aaa") < binaryDictionary.getFrequency("ccc")); if (supportsCountBasedNgram(toFormatVersion)) {
onInputWord(binaryDictionary, "bbb", false /* isValidWord */); assertTrue(binaryDictionary.getFrequency("aaa") < binaryDictionary.getFrequency("ccc"));
assertTrue(binaryDictionary.isValidWord("bbb")); onInputWord(binaryDictionary, "bbb", false /* isValidWord */);
assertTrue(binaryDictionary.isValidWord("bbb"));
}
assertTrue(isValidBigram(binaryDictionary, "aaa", "abc")); assertTrue(isValidBigram(binaryDictionary, "aaa", "abc"));
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb")); assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa"); if (supportsCountBasedNgram(toFormatVersion)) {
assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb")); onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa");
assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb"));
}
if (supportsNgram(toFormatVersion)) { if (supportsNgram(toFormatVersion)) {
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz")); assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def")); assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
@ -686,9 +691,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
public void testBeginningOfSentence() { public void testBeginningOfSentence() {
for (final int formatVersion : DICT_FORMAT_VERSIONS) { for (final int formatVersion : DICT_FORMAT_VERSIONS) {
if (supportsBeginningOfSentence(formatVersion)) { testBeginningOfSentence(formatVersion);
testBeginningOfSentence(formatVersion);
}
} }
} }
@ -716,10 +719,8 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
assertFalse(binaryDictionary.isValidNgram(beginningOfSentenceContext, "aaa")); assertFalse(binaryDictionary.isValidNgram(beginningOfSentenceContext, "aaa"));
assertFalse(binaryDictionary.isValidNgram(beginningOfSentenceContext, "bbb")); assertFalse(binaryDictionary.isValidNgram(beginningOfSentenceContext, "bbb"));
onInputWordWithBeginningOfSentenceContext(binaryDictionary, "aaa", true /* isValidWord */); onInputWordWithBeginningOfSentenceContext(binaryDictionary, "aaa", true /* isValidWord */);
assertFalse(binaryDictionary.isValidNgram(beginningOfSentenceContext, "aaa"));
onInputWordWithBeginningOfSentenceContext(binaryDictionary, "aaa", true /* isValidWord */); onInputWordWithBeginningOfSentenceContext(binaryDictionary, "aaa", true /* isValidWord */);
onInputWordWithBeginningOfSentenceContext(binaryDictionary, "bbb", true /* isValidWord */); onInputWordWithBeginningOfSentenceContext(binaryDictionary, "bbb", true /* isValidWord */);
assertFalse(binaryDictionary.isValidNgram(beginningOfSentenceContext, "bbb"));
onInputWordWithBeginningOfSentenceContext(binaryDictionary, "bbb", true /* isValidWord */); onInputWordWithBeginningOfSentenceContext(binaryDictionary, "bbb", true /* isValidWord */);
assertTrue(binaryDictionary.isValidNgram(beginningOfSentenceContext, "aaa")); assertTrue(binaryDictionary.isValidNgram(beginningOfSentenceContext, "aaa"));
assertTrue(binaryDictionary.isValidNgram(beginningOfSentenceContext, "bbb")); assertTrue(binaryDictionary.isValidNgram(beginningOfSentenceContext, "bbb"));