Support ngram entry migration.

Bug: 14425059
Change-Id: I98cb9fa303af2d93a0a3512e8732231c564e3c5d
This commit is contained in:
Keisuke Kuroyanagi 2014-10-22 11:31:16 +09:00
parent 0b8bb0c21b
commit c9865785f4
7 changed files with 146 additions and 43 deletions

View file

@ -629,8 +629,7 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j
}
} while (token != 0);
// Add bigrams.
// TODO: Support ngrams.
// Add ngrams.
do {
token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
const WordProperty wordProperty = dictionary->getWordProperty(

View file

@ -580,10 +580,12 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
bigramWord1CodePoints);
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
const int probability = bigramEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
bigramEntry.getHistoricalInfo(), mHeaderPolicy) :
bigramEntry.getProbability();
const int rawBigramProbability = bigramEntry.hasHistoricalInfo()
? ForgettingCurveUtils::decodeProbability(
bigramEntry.getHistoricalInfo(), mHeaderPolicy)
: bigramEntry.getProbability();
const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(),
ptNodeParams.representsBeginningOfSentence(), rawBigramProbability);
ngrams.emplace_back(
NgramContext(wordCodePoints.data(), wordCodePoints.size(),
ptNodeParams.representsBeginningOfSentence()),

View file

@ -140,6 +140,44 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}
std::vector<LanguageModelDictContent::DumppedFullEntryInfo>
LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
const HeaderPolicy *const headerPolicy, const int wordId) const {
const TrieMap::Result result = mTrieMap.getRoot(wordId);
if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
// The word doesn't have any related ngram entries.
return std::vector<DumppedFullEntryInfo>();
}
std::vector<int> prevWordIds = { wordId };
std::vector<DumppedFullEntryInfo> entries;
exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
&prevWordIds, &entries);
return entries;
}
void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
std::vector<int> *const prevWordIds,
std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
const int wordId = entry.key();
const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
if (probabilityEntry.isValid()) {
const WordAttributes wordAttributes = getWordAttributes(
WordIdArrayView(*prevWordIds), wordId, headerPolicy);
outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
wordAttributes, probabilityEntry);
}
if (entry.hasNextLevelMap()) {
prevWordIds->push_back(wordId);
exportAllNgramEntriesRelatedToWordInner(headerPolicy,
entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo);
prevWordIds->pop_back();
}
}
}
bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
MutableEntryCounters *const outEntryCounters) {
@ -231,24 +269,25 @@ bool LanguageModelDictContent::runGCInner(
}
int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
if (prevWordIds.empty()) {
return mTrieMap.getRootBitmapEntryIndex();
}
const int lastBitmapEntryIndex =
getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
return TrieMap::INVALID_INDEX;
}
const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID);
const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
if (!result.mIsValid) {
if (!mTrieMap.put(oldestPrevWordId,
ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) {
return TrieMap::INVALID_INDEX;
int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
for (const int wordId : prevWordIds) {
const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
continue;
}
if (!result.mIsValid) {
if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
lastBitmapEntryIndex)) {
AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
lastBitmapEntryIndex);
return TrieMap::INVALID_INDEX;
}
}
lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
lastBitmapEntryIndex);
}
return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID),
lastBitmapEntryIndex);
return lastBitmapEntryIndex;
}
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {

View file

@ -110,6 +110,27 @@ class LanguageModelDictContent {
const bool mHasHistoricalInfo;
};
class DumppedFullEntryInfo {
public:
DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
: mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
int getTargetWordId() const { return mTargetWordId; }
const WordAttributes &getWordAttributes() const { return mWordAttributes; }
const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
private:
DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
const std::vector<int> mPrevWordIds;
const int mTargetWordId;
const WordAttributes mWordAttributes;
const ProbabilityEntry mProbabilityEntry;
};
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
const bool hasHistoricalInfo)
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
@ -151,6 +172,9 @@ class LanguageModelDictContent {
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
const HeaderPolicy *const headerPolicy, const int wordId) const;
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
MutableEntryCounters *const outEntryCounters) {
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
@ -212,6 +236,9 @@ class LanguageModelDictContent {
const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
const bool isValid, const HistoricalInfo historicalInfo,
const HeaderPolicy *const headerPolicy) const;
void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */

View file

@ -491,30 +491,37 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
ptNodeParams.getTerminalId());
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
// Fetch bigram information.
// TODO: Support n-gram.
const LanguageModelDictContent *const languageModelDictContent =
mBuffers->getLanguageModelDictContent();
// Fetch ngram information.
std::vector<NgramProperty> ngrams;
const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
int bigramWord1CodePoints[MAX_WORD_LENGTH];
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
prevWordIds)) {
const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
MAX_WORD_LENGTH, bigramWord1CodePoints);
int ngramTargetCodePoints[MAX_WORD_LENGTH];
int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord(
mHeaderPolicy, wordId)) {
const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(),
MAX_WORD_LENGTH, ngramTargetCodePoints);
const WordIdArrayView prevWordIds = entry.getPrevWordIds();
for (size_t i = 0; i < prevWordIds.size(); ++i) {
ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i],
MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]);
ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry(
prevWordIds[i]).representsBeginningOfSentence();
if (ngramPrevWordIsBeginningOfSentense[i]) {
ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker(
ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]);
}
}
const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount,
ngramPrevWordIsBeginningOfSentense, prevWordIds.size());
const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry();
const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo();
const int probability = ngramProbabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) :
ngramProbabilityEntry.getProbability();
ngrams.emplace_back(
NgramContext(
wordCodePoints.data(), wordCodePoints.size(),
probabilityEntry.representsBeginningOfSentence()),
CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
probability, *historicalInfo);
// TODO: Output flags in WordAttributes.
ngrams.emplace_back(ngramContext,
CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(),
entry.getWordAttributes().getProbability(), *historicalInfo);
}
// Fetch shortcut information.
std::vector<UnigramProperty::ShortcutProperty> shortcuts;
@ -534,6 +541,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
shortcutProbability);
}
}
const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(
ptNodeParams.getTerminalId());
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(),

View file

@ -101,6 +101,17 @@ class CharUtils {
return codePointCount + 1;
}
// Returns updated code point count.
static AK_FORCE_INLINE int removeBeginningOfSentenceMarker(int *const codePoints,
const int codePointCount) {
if (codePointCount <= 0 || codePoints[0] != CODE_POINT_BEGINNING_OF_SENTENCE) {
return codePointCount;
}
const int newCodePointCount = codePointCount - 1;
memmove(codePoints, codePoints + 1, sizeof(int) * newCodePointCount);
return newCodePointCount;
}
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils);

View file

@ -653,6 +653,13 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
assertFalse(binaryDictionary.isValidWord("bbb"));
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
if (supportsNgram(toFormatVersion)) {
onInputWordWithPrevWords(binaryDictionary, "xyz", true, "abc", "aaa");
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa");
assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
}
assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion());
assertTrue(binaryDictionary.migrateTo(toFormatVersion));
assertTrue(binaryDictionary.isValidDictionary());
@ -666,6 +673,14 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa");
assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb"));
if (supportsNgram(toFormatVersion)) {
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa");
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
}
binaryDictionary.close();
}