Support ngram entry migration.
Bug: 14425059 Change-Id: I98cb9fa303af2d93a0a3512e8732231c564e3c5d
This commit is contained in:
parent
0b8bb0c21b
commit
c9865785f4
7 changed files with 146 additions and 43 deletions
|
@ -629,8 +629,7 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j
|
|||
}
|
||||
} while (token != 0);
|
||||
|
||||
// Add bigrams.
|
||||
// TODO: Support ngrams.
|
||||
// Add ngrams.
|
||||
do {
|
||||
token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount);
|
||||
const WordProperty wordProperty = dictionary->getWordProperty(
|
||||
|
|
|
@ -580,10 +580,12 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
|
|||
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
|
||||
bigramWord1CodePoints);
|
||||
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
|
||||
const int probability = bigramEntry.hasHistoricalInfo() ?
|
||||
ForgettingCurveUtils::decodeProbability(
|
||||
bigramEntry.getHistoricalInfo(), mHeaderPolicy) :
|
||||
bigramEntry.getProbability();
|
||||
const int rawBigramProbability = bigramEntry.hasHistoricalInfo()
|
||||
? ForgettingCurveUtils::decodeProbability(
|
||||
bigramEntry.getHistoricalInfo(), mHeaderPolicy)
|
||||
: bigramEntry.getProbability();
|
||||
const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(),
|
||||
ptNodeParams.representsBeginningOfSentence(), rawBigramProbability);
|
||||
ngrams.emplace_back(
|
||||
NgramContext(wordCodePoints.data(), wordCodePoints.size(),
|
||||
ptNodeParams.representsBeginningOfSentence()),
|
||||
|
|
|
@ -140,6 +140,44 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
|
|||
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
|
||||
}
|
||||
|
||||
std::vector<LanguageModelDictContent::DumppedFullEntryInfo>
|
||||
LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
|
||||
const HeaderPolicy *const headerPolicy, const int wordId) const {
|
||||
const TrieMap::Result result = mTrieMap.getRoot(wordId);
|
||||
if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
||||
// The word doesn't have any related ngram entries.
|
||||
return std::vector<DumppedFullEntryInfo>();
|
||||
}
|
||||
std::vector<int> prevWordIds = { wordId };
|
||||
std::vector<DumppedFullEntryInfo> entries;
|
||||
exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
|
||||
&prevWordIds, &entries);
|
||||
return entries;
|
||||
}
|
||||
|
||||
void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
|
||||
const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
|
||||
std::vector<int> *const prevWordIds,
|
||||
std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const {
|
||||
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
|
||||
const int wordId = entry.key();
|
||||
const ProbabilityEntry probabilityEntry =
|
||||
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
|
||||
if (probabilityEntry.isValid()) {
|
||||
const WordAttributes wordAttributes = getWordAttributes(
|
||||
WordIdArrayView(*prevWordIds), wordId, headerPolicy);
|
||||
outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
|
||||
wordAttributes, probabilityEntry);
|
||||
}
|
||||
if (entry.hasNextLevelMap()) {
|
||||
prevWordIds->push_back(wordId);
|
||||
exportAllNgramEntriesRelatedToWordInner(headerPolicy,
|
||||
entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo);
|
||||
prevWordIds->pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool LanguageModelDictContent::truncateEntries(const EntryCounts ¤tEntryCounts,
|
||||
const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
|
||||
MutableEntryCounters *const outEntryCounters) {
|
||||
|
@ -231,24 +269,25 @@ bool LanguageModelDictContent::runGCInner(
|
|||
}
|
||||
|
||||
int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
|
||||
if (prevWordIds.empty()) {
|
||||
return mTrieMap.getRootBitmapEntryIndex();
|
||||
int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
|
||||
for (const int wordId : prevWordIds) {
|
||||
const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
|
||||
if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
|
||||
lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
|
||||
continue;
|
||||
}
|
||||
const int lastBitmapEntryIndex =
|
||||
getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
|
||||
if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
||||
return TrieMap::INVALID_INDEX;
|
||||
}
|
||||
const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID);
|
||||
const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
|
||||
if (!result.mIsValid) {
|
||||
if (!mTrieMap.put(oldestPrevWordId,
|
||||
ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) {
|
||||
if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
|
||||
lastBitmapEntryIndex)) {
|
||||
AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
|
||||
lastBitmapEntryIndex);
|
||||
return TrieMap::INVALID_INDEX;
|
||||
}
|
||||
}
|
||||
return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID),
|
||||
lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
|
||||
lastBitmapEntryIndex);
|
||||
}
|
||||
return lastBitmapEntryIndex;
|
||||
}
|
||||
|
||||
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
|
||||
|
|
|
@ -110,6 +110,27 @@ class LanguageModelDictContent {
|
|||
const bool mHasHistoricalInfo;
|
||||
};
|
||||
|
||||
class DumppedFullEntryInfo {
|
||||
public:
|
||||
DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
|
||||
const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
|
||||
: mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
|
||||
mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
|
||||
|
||||
const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
|
||||
int getTargetWordId() const { return mTargetWordId; }
|
||||
const WordAttributes &getWordAttributes() const { return mWordAttributes; }
|
||||
const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
|
||||
|
||||
private:
|
||||
DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
|
||||
|
||||
const std::vector<int> mPrevWordIds;
|
||||
const int mTargetWordId;
|
||||
const WordAttributes mWordAttributes;
|
||||
const ProbabilityEntry mProbabilityEntry;
|
||||
};
|
||||
|
||||
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
|
||||
const bool hasHistoricalInfo)
|
||||
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
|
||||
|
@ -151,6 +172,9 @@ class LanguageModelDictContent {
|
|||
|
||||
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
|
||||
|
||||
std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
|
||||
const HeaderPolicy *const headerPolicy, const int wordId) const;
|
||||
|
||||
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
|
||||
MutableEntryCounters *const outEntryCounters) {
|
||||
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
|
||||
|
@ -212,6 +236,9 @@ class LanguageModelDictContent {
|
|||
const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
|
||||
const bool isValid, const HistoricalInfo historicalInfo,
|
||||
const HeaderPolicy *const headerPolicy) const;
|
||||
void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
|
||||
const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
|
||||
std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
|
||||
};
|
||||
} // namespace latinime
|
||||
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
|
||||
|
|
|
@ -491,30 +491,37 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
|
|||
const int ptNodePos =
|
||||
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
|
||||
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
||||
const ProbabilityEntry probabilityEntry =
|
||||
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
|
||||
ptNodeParams.getTerminalId());
|
||||
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
|
||||
// Fetch bigram information.
|
||||
// TODO: Support n-gram.
|
||||
const LanguageModelDictContent *const languageModelDictContent =
|
||||
mBuffers->getLanguageModelDictContent();
|
||||
// Fetch ngram information.
|
||||
std::vector<NgramProperty> ngrams;
|
||||
const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
|
||||
int bigramWord1CodePoints[MAX_WORD_LENGTH];
|
||||
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
|
||||
prevWordIds)) {
|
||||
const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
|
||||
MAX_WORD_LENGTH, bigramWord1CodePoints);
|
||||
int ngramTargetCodePoints[MAX_WORD_LENGTH];
|
||||
int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
|
||||
int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
|
||||
for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord(
|
||||
mHeaderPolicy, wordId)) {
|
||||
const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(),
|
||||
MAX_WORD_LENGTH, ngramTargetCodePoints);
|
||||
const WordIdArrayView prevWordIds = entry.getPrevWordIds();
|
||||
for (size_t i = 0; i < prevWordIds.size(); ++i) {
|
||||
ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i],
|
||||
MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]);
|
||||
ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry(
|
||||
prevWordIds[i]).representsBeginningOfSentence();
|
||||
if (ngramPrevWordIsBeginningOfSentense[i]) {
|
||||
ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker(
|
||||
ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]);
|
||||
}
|
||||
}
|
||||
const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount,
|
||||
ngramPrevWordIsBeginningOfSentense, prevWordIds.size());
|
||||
const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry();
|
||||
const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo();
|
||||
const int probability = ngramProbabilityEntry.hasHistoricalInfo() ?
|
||||
ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) :
|
||||
ngramProbabilityEntry.getProbability();
|
||||
ngrams.emplace_back(
|
||||
NgramContext(
|
||||
wordCodePoints.data(), wordCodePoints.size(),
|
||||
probabilityEntry.representsBeginningOfSentence()),
|
||||
CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
|
||||
probability, *historicalInfo);
|
||||
// TODO: Output flags in WordAttributes.
|
||||
ngrams.emplace_back(ngramContext,
|
||||
CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(),
|
||||
entry.getWordAttributes().getProbability(), *historicalInfo);
|
||||
}
|
||||
// Fetch shortcut information.
|
||||
std::vector<UnigramProperty::ShortcutProperty> shortcuts;
|
||||
|
@ -534,6 +541,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
|
|||
shortcutProbability);
|
||||
}
|
||||
}
|
||||
const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(
|
||||
ptNodeParams.getTerminalId());
|
||||
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
|
||||
const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
|
||||
probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
|
||||
probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(),
|
||||
|
|
|
@ -101,6 +101,17 @@ class CharUtils {
|
|||
return codePointCount + 1;
|
||||
}
|
||||
|
||||
// Returns updated code point count.
|
||||
static AK_FORCE_INLINE int removeBeginningOfSentenceMarker(int *const codePoints,
|
||||
const int codePointCount) {
|
||||
if (codePointCount <= 0 || codePoints[0] != CODE_POINT_BEGINNING_OF_SENTENCE) {
|
||||
return codePointCount;
|
||||
}
|
||||
const int newCodePointCount = codePointCount - 1;
|
||||
memmove(codePoints, codePoints + 1, sizeof(int) * newCodePointCount);
|
||||
return newCodePointCount;
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils);
|
||||
|
||||
|
|
|
@ -653,6 +653,13 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
|
|||
assertFalse(binaryDictionary.isValidWord("bbb"));
|
||||
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
|
||||
|
||||
if (supportsNgram(toFormatVersion)) {
|
||||
onInputWordWithPrevWords(binaryDictionary, "xyz", true, "abc", "aaa");
|
||||
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
|
||||
onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa");
|
||||
assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
|
||||
}
|
||||
|
||||
assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion());
|
||||
assertTrue(binaryDictionary.migrateTo(toFormatVersion));
|
||||
assertTrue(binaryDictionary.isValidDictionary());
|
||||
|
@ -666,6 +673,14 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
|
|||
assertFalse(isValidBigram(binaryDictionary, "aaa", "bbb"));
|
||||
onInputWordWithPrevWord(binaryDictionary, "bbb", false /* isValidWord */, "aaa");
|
||||
assertTrue(isValidBigram(binaryDictionary, "aaa", "bbb"));
|
||||
|
||||
if (supportsNgram(toFormatVersion)) {
|
||||
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "xyz"));
|
||||
assertFalse(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
|
||||
onInputWordWithPrevWords(binaryDictionary, "def", false, "abc", "aaa");
|
||||
assertTrue(isValidTrigram(binaryDictionary, "aaa", "abc", "def"));
|
||||
}
|
||||
|
||||
binaryDictionary.close();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue