Support ngram entry migration.
Bug: 14425059 Change-Id: I98cb9fa303af2d93a0a3512e8732231c564e3c5d
This commit is contained in:
parent
0b8bb0c21b
commit
c9865785f4
7 changed files with 146 additions and 43 deletions
|
@ -629,8 +629,7 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j
|
||||||
}
|
}
|
||||||
} while (token != 0);
|
} while (token != 0);
|
||||||
|
|
||||||
// Add bigrams.
|
// Add ngrams.
|
||||||
// TODO: Support ngrams.
|
|
||||||
do {
|
do {
|
||||||
token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
|
token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
|
||||||
const WordProperty wordProperty = dictionary->getWordProperty(
|
const WordProperty wordProperty = dictionary->getWordProperty(
|
||||||
|
|
|
@ -580,10 +580,12 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
|
||||||
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
|
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
|
||||||
bigramWord1CodePoints);
|
bigramWord1CodePoints);
|
||||||
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
|
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
|
||||||
const int probability = bigramEntry.hasHistoricalInfo() ?
|
const int rawBigramProbability = bigramEntry.hasHistoricalInfo()
|
||||||
ForgettingCurveUtils::decodeProbability(
|
? ForgettingCurveUtils::decodeProbability(
|
||||||
bigramEntry.getHistoricalInfo(), mHeaderPolicy) :
|
bigramEntry.getHistoricalInfo(), mHeaderPolicy)
|
||||||
bigramEntry.getProbability();
|
: bigramEntry.getProbability();
|
||||||
|
const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(),
|
||||||
|
ptNodeParams.representsBeginningOfSentence(), rawBigramProbability);
|
||||||
ngrams.emplace_back(
|
ngrams.emplace_back(
|
||||||
NgramContext(wordCodePoints.data(), wordCodePoints.size(),
|
NgramContext(wordCodePoints.data(), wordCodePoints.size(),
|
||||||
ptNodeParams.representsBeginningOfSentence()),
|
ptNodeParams.representsBeginningOfSentence()),
|
||||||
|
|
|
@ -140,6 +140,44 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
|
||||||
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
|
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<LanguageModelDictContent::DumppedFullEntryInfo>
|
||||||
|
LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
|
||||||
|
const HeaderPolicy *const headerPolicy, const int wordId) const {
|
||||||
|
const TrieMap::Result result = mTrieMap.getRoot(wordId);
|
||||||
|
if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
||||||
|
// The word doesn't have any related ngram entries.
|
||||||
|
return std::vector<DumppedFullEntryInfo>();
|
||||||
|
}
|
||||||
|
std::vector<int> prevWordIds = { wordId };
|
||||||
|
std::vector<DumppedFullEntryInfo> entries;
|
||||||
|
exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
|
||||||
|
&prevWordIds, &entries);
|
||||||
|
return entries;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
|
||||||
|
const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
|
||||||
|
std::vector<int> *const prevWordIds,
|
||||||
|
std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const {
|
||||||
|
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
|
||||||
|
const int wordId = entry.key();
|
||||||
|
const ProbabilityEntry probabilityEntry =
|
||||||
|
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
|
||||||
|
if (probabilityEntry.isValid()) {
|
||||||
|
const WordAttributes wordAttributes = getWordAttributes(
|
||||||
|
WordIdArrayView(*prevWordIds), wordId, headerPolicy);
|
||||||
|
outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
|
||||||
|
wordAttributes, probabilityEntry);
|
||||||
|
}
|
||||||
|
if (entry.hasNextLevelMap()) {
|
||||||
|
prevWordIds->push_back(wordId);
|
||||||
|
exportAllNgramEntriesRelatedToWordInner(headerPolicy,
|
||||||
|
entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo);
|
||||||
|
prevWordIds->pop_back();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
|
bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
|
||||||
const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
|
const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
|
||||||
MutableEntryCounters *const outEntryCounters) {
|
MutableEntryCounters *const outEntryCounters) {
|
||||||
|
@ -231,24 +269,25 @@ bool LanguageModelDictContent::runGCInner(
|
||||||
}
|
}
|
||||||
|
|
||||||
int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
|
int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
|
||||||
if (prevWordIds.empty()) {
|
int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
|
||||||
return mTrieMap.getRootBitmapEntryIndex();
|
for (const int wordId : prevWordIds) {
|
||||||
|
const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
|
||||||
|
if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
|
||||||
|
lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
const int lastBitmapEntryIndex =
|
|
||||||
getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
|
|
||||||
if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
|
||||||
return TrieMap::INVALID_INDEX;
|
|
||||||
}
|
|
||||||
const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID);
|
|
||||||
const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
|
|
||||||
if (!result.mIsValid) {
|
if (!result.mIsValid) {
|
||||||
if (!mTrieMap.put(oldestPrevWordId,
|
if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
|
||||||
ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) {
|
lastBitmapEntryIndex)) {
|
||||||
|
AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
|
||||||
|
lastBitmapEntryIndex);
|
||||||
return TrieMap::INVALID_INDEX;
|
return TrieMap::INVALID_INDEX;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID),
|
lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
|
||||||
lastBitmapEntryIndex);
|
lastBitmapEntryIndex);
|
||||||
|
}
|
||||||
|
return lastBitmapEntryIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
|
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
|
||||||
|
|
|
@ -110,6 +110,27 @@ class LanguageModelDictContent {
|
||||||
const bool mHasHistoricalInfo;
|
const bool mHasHistoricalInfo;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class DumppedFullEntryInfo {
|
||||||
|
public:
|
||||||
|
DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
|
||||||
|
const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
|
||||||
|
: mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
|
||||||
|
mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
|
||||||
|
|
||||||
|
const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
|
||||||
|
int getTargetWordId() const { return mTargetWordId; }
|
||||||
|
const WordAttributes &getWordAttributes() const { return mWordAttributes; }
|
||||||
|
const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
|
||||||
|
|
||||||
|
const std::vector<int> mPrevWordIds;
|
||||||
|
const int mTargetWordId;
|
||||||
|
const WordAttributes mWordAttributes;
|
||||||
|
const ProbabilityEntry mProbabilityEntry;
|
||||||
|
};
|
||||||
|
|
||||||
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
|
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
|
||||||
const bool hasHistoricalInfo)
|
const bool hasHistoricalInfo)
|
||||||
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
|
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
|
||||||
|
@ -151,6 +172,9 @@ class LanguageModelDictContent {
|
||||||
|
|
||||||
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
|
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
|
||||||
|
|
||||||
|
std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
|
||||||
|
const HeaderPolicy *const headerPolicy, const int wordId) const;
|
||||||
|
|
||||||
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
|
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
|
||||||
MutableEntryCounters *const outEntryCounters) {
|
MutableEntryCounters *const outEntryCounters) {
|
||||||
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
|
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
|
||||||
|
@ -212,6 +236,9 @@ class LanguageModelDictContent {
|
||||||
const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
|
const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
|
||||||
const bool isValid, const HistoricalInfo historicalInfo,
|
const bool isValid, const HistoricalInfo historicalInfo,
|
||||||
const HeaderPolicy *const headerPolicy) const;
|
const HeaderPolicy *const headerPolicy) const;
|
||||||
|
void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
|
||||||
|
const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
|
||||||
|
std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
|
||||||
};
|
};
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
|
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
|
||||||
|
|
|
@ -491,30 +491,37 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
|
||||||
const int ptNodePos =
|
const int ptNodePos =
|
||||||
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
|
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
|
||||||
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
||||||
const ProbabilityEntry probabilityEntry =
|
const LanguageModelDictContent *const languageModelDictContent =
|
||||||
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
|
mBuffers->getLanguageModelDictContent();
|
||||||
ptNodeParams.getTerminalId());
|
// Fetch ngram information.
|
||||||
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
|
|
||||||
// Fetch bigram information.
|
|
||||||
// TODO: Support n-gram.
|
|
||||||
std::vector<NgramProperty> ngrams;
|
std::vector<NgramProperty> ngrams;
|
||||||
const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
|
int ngramTargetCodePoints[MAX_WORD_LENGTH];
|
||||||
int bigramWord1CodePoints[MAX_WORD_LENGTH];
|
int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
|
||||||
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
|
int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||||
prevWordIds)) {
|
bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||||
const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
|
for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord(
|
||||||
MAX_WORD_LENGTH, bigramWord1CodePoints);
|
mHeaderPolicy, wordId)) {
|
||||||
|
const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(),
|
||||||
|
MAX_WORD_LENGTH, ngramTargetCodePoints);
|
||||||
|
const WordIdArrayView prevWordIds = entry.getPrevWordIds();
|
||||||
|
for (size_t i = 0; i < prevWordIds.size(); ++i) {
|
||||||
|
ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i],
|
||||||
|
MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]);
|
||||||
|
ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry(
|
||||||
|
prevWordIds[i]).representsBeginningOfSentence();
|
||||||
|
if (ngramPrevWordIsBeginningOfSentense[i]) {
|
||||||
|
ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker(
|
||||||
|
ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount,
|
||||||
|
ngramPrevWordIsBeginningOfSentense, prevWordIds.size());
|
||||||
const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry();
|
const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry();
|
||||||
const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo();
|
const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo();
|
||||||
const int probability = ngramProbabilityEntry.hasHistoricalInfo() ?
|
// TODO: Output flags in WordAttributes.
|
||||||
ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) :
|
ngrams.emplace_back(ngramContext,
|
||||||
ngramProbabilityEntry.getProbability();
|
CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(),
|
||||||
ngrams.emplace_back(
|
entry.getWordAttributes().getProbability(), *historicalInfo);
|
||||||
NgramContext(
|
|
||||||
wordCodePoints.data(), wordCodePoints.size(),
|
|
||||||
probabilityEntry.representsBeginningOfSentence()),
|
|
||||||
CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
|
|
||||||
probability, *historicalInfo);
|
|
||||||
}
|
}
|
||||||
// Fetch shortcut information.
|
// Fetch shortcut information.
|
||||||
std::vector<UnigramProperty::ShortcutProperty> shortcuts;
|
std::vector<UnigramProperty::ShortcutProperty> shortcuts;
|
||||||
|
@ -534,6 +541,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
|
||||||
shortcutProbability);
|
shortcutProbability);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(
|
||||||
|
ptNodeParams.getTerminalId());
|
||||||
|
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
|
||||||
const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
|
const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
|
||||||
probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
|
probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
|
||||||
probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(),
|
probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(),
|
||||||
|
|
|
@ -101,6 +101,17 @@ class CharUtils {
|
||||||
return codePointCount + 1;
|
return codePointCount + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns updated code point count.
|
||||||
|
static AK_FORCE_INLINE int removeBeginningOfSentenceMarker(int *const codePoints,
|
||||||
|
const int codePointCount) {
|
||||||
|
if (codePointCount <= 0 || codePoints[0] != CODE_POINT_BEGINNING_OF_SENTENCE) {
|
||||||
|
return codePointCount;
|
||||||
|
}
|
||||||
|
const int newCodePointCount = codePointCount - 1;
|
||||||
|
memmove(codePoints, codePoints + 1, sizeof(int) * newCodePointCount);
|
||||||
|
return newCodePointCount;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils);
|
DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils);
|
||||||
|
|
||||||
|
|
|
@ -653,6 +653,13 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
|
||||||
assertFalse(binaryDictionary.isValidWord("bbb"));
|
assertFalse(binaryDictionary.isValidWord("bbb"));
|
||||||
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
|
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
|
||||||
|
|
||||||
|
if (supportsNgram(toFormatVersion)) {
|
||||||
|
onInputWordWithPrevWords(binaryDictionary, "xyz", true, "abc", "aaa");
|
||||||
|
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
|
||||||
|
onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa");
|
||||||
|
assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
|
||||||
|
}
|
||||||
|
|
||||||
assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion());
|
assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion());
|
||||||
assertTrue(binaryDictionary.migrateTo(toFormatVersion));
|
assertTrue(binaryDictionary.migrateTo(toFormatVersion));
|
||||||
assertTrue(binaryDictionary.isValidDictionary());
|
assertTrue(binaryDictionary.isValidDictionary());
|
||||||
|
@ -666,6 +673,14 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
|
||||||
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
|
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
|
||||||
onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa");
|
onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa");
|
||||||
assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb"));
|
assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb"));
|
||||||
|
|
||||||
|
if (supportsNgram(toFormatVersion)) {
|
||||||
|
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
|
||||||
|
assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
|
||||||
|
onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa");
|
||||||
|
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
|
||||||
|
}
|
||||||
|
|
||||||
binaryDictionary.close();
|
binaryDictionary.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue