Support ngram entry migration.

Bug: 14425059
Change-Id: I98cb9fa303af2d93a0a3512e8732231c564e3c5d
This commit is contained in:
Keisuke Kuroyanagi 2014-10-22 11:31:16 +09:00
parent 0b8bb0c21b
commit c9865785f4
7 changed files with 146 additions and 43 deletions

View file

@ -629,8 +629,7 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j
} }
} while (token != 0); } while (token != 0);
// Add bigrams. // Add ngrams.
// TODO: Support ngrams.
do { do {
token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount); token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
const WordProperty wordProperty = dictionary->getWordProperty( const WordProperty wordProperty = dictionary->getWordProperty(

View file

@ -580,10 +580,12 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH, getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
bigramWord1CodePoints); bigramWord1CodePoints);
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
const int probability = bigramEntry.hasHistoricalInfo() ? const int rawBigramProbability = bigramEntry.hasHistoricalInfo()
ForgettingCurveUtils::decodeProbability( ? ForgettingCurveUtils::decodeProbability(
bigramEntry.getHistoricalInfo(), mHeaderPolicy) : bigramEntry.getHistoricalInfo(), mHeaderPolicy)
bigramEntry.getProbability(); : bigramEntry.getProbability();
const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(),
ptNodeParams.representsBeginningOfSentence(), rawBigramProbability);
ngrams.emplace_back( ngrams.emplace_back(
NgramContext(wordCodePoints.data(), wordCodePoints.size(), NgramContext(wordCodePoints.data(), wordCodePoints.size(),
ptNodeParams.representsBeginningOfSentence()), ptNodeParams.representsBeginningOfSentence()),

View file

@ -140,6 +140,44 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo); return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
} }
std::vector<LanguageModelDictContent::DumppedFullEntryInfo>
LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
const HeaderPolicy *const headerPolicy, const int wordId) const {
const TrieMap::Result result = mTrieMap.getRoot(wordId);
if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
// The word doesn't have any related ngram entries.
return std::vector<DumppedFullEntryInfo>();
}
std::vector<int> prevWordIds = { wordId };
std::vector<DumppedFullEntryInfo> entries;
exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
&prevWordIds, &entries);
return entries;
}
void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
std::vector<int> *const prevWordIds,
std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
const int wordId = entry.key();
const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
if (probabilityEntry.isValid()) {
const WordAttributes wordAttributes = getWordAttributes(
WordIdArrayView(*prevWordIds), wordId, headerPolicy);
outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
wordAttributes, probabilityEntry);
}
if (entry.hasNextLevelMap()) {
prevWordIds->push_back(wordId);
exportAllNgramEntriesRelatedToWordInner(headerPolicy,
entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo);
prevWordIds->pop_back();
}
}
}
bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts, bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy, const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
MutableEntryCounters *const outEntryCounters) { MutableEntryCounters *const outEntryCounters) {
@ -231,24 +269,25 @@ bool LanguageModelDictContent::runGCInner(
} }
int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) { int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
if (prevWordIds.empty()) { int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
return mTrieMap.getRootBitmapEntryIndex(); for (const int wordId : prevWordIds) {
const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
continue;
} }
const int lastBitmapEntryIndex =
getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
return TrieMap::INVALID_INDEX;
}
const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID);
const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
if (!result.mIsValid) { if (!result.mIsValid) {
if (!mTrieMap.put(oldestPrevWordId, if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) { lastBitmapEntryIndex)) {
AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
lastBitmapEntryIndex);
return TrieMap::INVALID_INDEX; return TrieMap::INVALID_INDEX;
} }
} }
return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID), lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
lastBitmapEntryIndex); lastBitmapEntryIndex);
}
return lastBitmapEntryIndex;
} }
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const { int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {

View file

@ -110,6 +110,27 @@ class LanguageModelDictContent {
const bool mHasHistoricalInfo; const bool mHasHistoricalInfo;
}; };
class DumppedFullEntryInfo {
public:
DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
: mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
int getTargetWordId() const { return mTargetWordId; }
const WordAttributes &getWordAttributes() const { return mWordAttributes; }
const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
private:
DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
const std::vector<int> mPrevWordIds;
const int mTargetWordId;
const WordAttributes mWordAttributes;
const ProbabilityEntry mProbabilityEntry;
};
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer, LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
const bool hasHistoricalInfo) const bool hasHistoricalInfo)
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {} : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
@ -151,6 +172,9 @@ class LanguageModelDictContent {
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const; EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
const HeaderPolicy *const headerPolicy, const int wordId) const;
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
MutableEntryCounters *const outEntryCounters) { MutableEntryCounters *const outEntryCounters) {
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
@ -212,6 +236,9 @@ class LanguageModelDictContent {
const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry, const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
const bool isValid, const HistoricalInfo historicalInfo, const bool isValid, const HistoricalInfo historicalInfo,
const HeaderPolicy *const headerPolicy) const; const HeaderPolicy *const headerPolicy) const;
void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
}; };
} // namespace latinime } // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */

View file

@ -491,30 +491,37 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
const int ptNodePos = const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
const ProbabilityEntry probabilityEntry = const LanguageModelDictContent *const languageModelDictContent =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry( mBuffers->getLanguageModelDictContent();
ptNodeParams.getTerminalId()); // Fetch ngram information.
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
// Fetch bigram information.
// TODO: Support n-gram.
std::vector<NgramProperty> ngrams; std::vector<NgramProperty> ngrams;
const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId); int ngramTargetCodePoints[MAX_WORD_LENGTH];
int bigramWord1CodePoints[MAX_WORD_LENGTH]; int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries( int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordIds)) { bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(), for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord(
MAX_WORD_LENGTH, bigramWord1CodePoints); mHeaderPolicy, wordId)) {
const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(),
MAX_WORD_LENGTH, ngramTargetCodePoints);
const WordIdArrayView prevWordIds = entry.getPrevWordIds();
for (size_t i = 0; i < prevWordIds.size(); ++i) {
ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i],
MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]);
ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry(
prevWordIds[i]).representsBeginningOfSentence();
if (ngramPrevWordIsBeginningOfSentense[i]) {
ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker(
ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]);
}
}
const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount,
ngramPrevWordIsBeginningOfSentense, prevWordIds.size());
const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry(); const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry();
const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo(); const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo();
const int probability = ngramProbabilityEntry.hasHistoricalInfo() ? // TODO: Output flags in WordAttributes.
ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) : ngrams.emplace_back(ngramContext,
ngramProbabilityEntry.getProbability(); CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(),
ngrams.emplace_back( entry.getWordAttributes().getProbability(), *historicalInfo);
NgramContext(
wordCodePoints.data(), wordCodePoints.size(),
probabilityEntry.representsBeginningOfSentence()),
CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
probability, *historicalInfo);
} }
// Fetch shortcut information. // Fetch shortcut information.
std::vector<UnigramProperty::ShortcutProperty> shortcuts; std::vector<UnigramProperty::ShortcutProperty> shortcuts;
@ -534,6 +541,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
shortcutProbability); shortcutProbability);
} }
} }
const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(
ptNodeParams.getTerminalId());
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(), const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(), probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(), probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(),

View file

@ -101,6 +101,17 @@ class CharUtils {
return codePointCount + 1; return codePointCount + 1;
} }
// Returns updated code point count.
static AK_FORCE_INLINE int removeBeginningOfSentenceMarker(int *const codePoints,
const int codePointCount) {
if (codePointCount <= 0 || codePoints[0] != CODE_POINT_BEGINNING_OF_SENTENCE) {
return codePointCount;
}
const int newCodePointCount = codePointCount - 1;
memmove(codePoints, codePoints + 1, sizeof(int) * newCodePointCount);
return newCodePointCount;
}
private: private:
DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils); DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils);

View file

@ -653,6 +653,13 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
assertFalse(binaryDictionary.isValidWord("bbb")); assertFalse(binaryDictionary.isValidWord("bbb"));
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb")); assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
if (supportsNgram(toFormatVersion)) {
onInputWordWithPrevWords(binaryDictionary, "xyz", true, "abc", "aaa");
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa");
assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
}
assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion()); assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion());
assertTrue(binaryDictionary.migrateTo(toFormatVersion)); assertTrue(binaryDictionary.migrateTo(toFormatVersion));
assertTrue(binaryDictionary.isValidDictionary()); assertTrue(binaryDictionary.isValidDictionary());
@ -666,6 +673,14 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb")); assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa"); onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa");
assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb")); assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb"));
if (supportsNgram(toFormatVersion)) {
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa");
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
}
binaryDictionary.close(); binaryDictionary.close();
} }