2014-08-01 11:19:16 +00:00
|
|
|
/*
|
|
|
|
* Copyright (C) 2014, The Android Open Source Project
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
|
|
|
|
|
2014-08-21 03:48:24 +00:00
|
|
|
#include <algorithm>
|
|
|
|
#include <cstring>
|
|
|
|
|
2014-08-22 11:07:54 +00:00
|
|
|
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
|
|
|
|
|
2014-08-01 11:19:16 +00:00
|
|
|
namespace latinime {
|
|
|
|
|
2014-08-27 11:04:39 +00:00
|
|
|
// Position of the unigram count in an entry-count table. The entry-count arrays handled in this
// file (truncateEntries(), updateAllProbabilityEntriesInner()) are indexed by the number of
// previous words, so index 0 holds the count of entries with no previous word.
const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
|
|
|
|
// Position of the bigram count (entries with exactly one previous word) in an entry-count table.
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
|
|
|
|
|
2014-08-01 11:19:16 +00:00
|
|
|
bool LanguageModelDictContent::save(FILE *const file) const {
|
|
|
|
return mTrieMap.save(file);
|
|
|
|
}
|
|
|
|
|
2014-08-05 03:38:55 +00:00
|
|
|
bool LanguageModelDictContent::runGC(
|
|
|
|
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
|
|
|
|
const LanguageModelDictContent *const originalContent,
|
|
|
|
int *const outNgramCount) {
|
|
|
|
return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
|
|
|
|
0 /* nextLevelBitmapEntryIndex */, outNgramCount);
|
|
|
|
}
|
|
|
|
|
2014-09-10 10:51:12 +00:00
|
|
|
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
|
2014-09-12 11:17:41 +00:00
|
|
|
const int wordId, const HeaderPolicy *const headerPolicy) const {
|
2014-09-10 10:51:12 +00:00
|
|
|
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
|
|
|
|
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
|
|
|
|
int maxLevel = 0;
|
|
|
|
for (size_t i = 0; i < prevWordIds.size(); ++i) {
|
|
|
|
const int nextBitmapEntryIndex =
|
|
|
|
mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
|
|
|
|
if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
maxLevel = i + 1;
|
|
|
|
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
|
|
|
|
}
|
|
|
|
|
|
|
|
for (int i = maxLevel; i >= 0; --i) {
|
|
|
|
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
|
|
|
|
if (!result.mIsValid) {
|
|
|
|
continue;
|
|
|
|
}
|
2014-09-12 11:17:41 +00:00
|
|
|
const ProbabilityEntry probabilityEntry =
|
|
|
|
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
|
2014-09-10 10:51:12 +00:00
|
|
|
if (mHasHistoricalInfo) {
|
2014-09-12 11:17:41 +00:00
|
|
|
const int probability = ForgettingCurveUtils::decodeProbability(
|
|
|
|
probabilityEntry.getHistoricalInfo(), headerPolicy)
|
|
|
|
+ ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */);
|
|
|
|
return std::min(probability, MAX_PROBABILITY);
|
2014-09-10 10:51:12 +00:00
|
|
|
} else {
|
2014-09-12 11:17:41 +00:00
|
|
|
return probabilityEntry.getProbability();
|
2014-09-10 10:51:12 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
// Cannot find the word.
|
|
|
|
return NOT_A_PROBABILITY;
|
|
|
|
}
|
|
|
|
|
2014-08-05 05:13:07 +00:00
|
|
|
// Returns the decoded entry stored for |wordId| under the context |prevWordIds|. A
// default-constructed ProbabilityEntry is returned when the context level does not exist or the
// word is not stored in it.
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
        const WordIdArrayView prevWordIds, const int wordId) const {
    const int contextBitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
    if (contextBitmapEntryIndex != TrieMap::INVALID_INDEX) {
        const TrieMap::Result lookup = mTrieMap.get(wordId, contextBitmapEntryIndex);
        if (lookup.mIsValid) {
            return ProbabilityEntry::decode(lookup.mValue, mHasHistoricalInfo);
        }
    }
    // Either the context or the word itself is absent.
    return ProbabilityEntry();
}
|
|
|
|
|
2014-08-05 05:13:07 +00:00
|
|
|
bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
|
2014-08-15 10:55:07 +00:00
|
|
|
const int wordId, const ProbabilityEntry *const probabilityEntry) {
|
|
|
|
if (wordId == Ver4DictConstants::NOT_A_TERMINAL_ID) {
|
|
|
|
return false;
|
|
|
|
}
|
2014-08-12 11:32:42 +00:00
|
|
|
const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
|
2014-08-05 05:51:11 +00:00
|
|
|
if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
2014-08-05 03:38:55 +00:00
|
|
|
return false;
|
|
|
|
}
|
2014-08-15 10:55:07 +00:00
|
|
|
return mTrieMap.put(wordId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex);
|
2014-08-18 03:34:48 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView prevWordIds,
|
|
|
|
const int wordId) {
|
|
|
|
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
|
|
|
|
if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
|
|
|
// Cannot find bitmap entry for the probability entry. The entry doesn't exist.
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
return mTrieMap.remove(wordId, bitmapEntryIndex);
|
2014-08-05 03:38:55 +00:00
|
|
|
}
|
|
|
|
|
2014-08-26 03:01:08 +00:00
|
|
|
// Returns a range over all entries stored directly under the context |prevWordIds|; each entry
// is decoded according to this content's historical-info mode.
// NOTE(review): when the context is absent, getBitmapEntryIndex() yields TrieMap::INVALID_INDEX
// and that index is passed straight to getEntriesInSpecifiedLevel(); confirm TrieMap treats it
// as an empty level.
LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries(
        const WordIdArrayView prevWordIds) const {
    return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(getBitmapEntryIndex(prevWordIds)),
            mHasHistoricalInfo);
}
|
|
|
|
|
2014-08-21 03:48:24 +00:00
|
|
|
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
|
2014-08-27 11:04:39 +00:00
|
|
|
const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy,
|
|
|
|
int *const outEntryCounts) {
|
2014-08-21 03:48:24 +00:00
|
|
|
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
|
|
|
if (entryCounts[i] <= maxEntryCounts[i]) {
|
2014-08-27 11:04:39 +00:00
|
|
|
outEntryCounts[i] = entryCounts[i];
|
2014-08-21 03:48:24 +00:00
|
|
|
continue;
|
|
|
|
}
|
2014-08-27 11:04:39 +00:00
|
|
|
if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i,
|
|
|
|
&outEntryCounts[i])) {
|
2014-08-21 03:48:24 +00:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2014-08-05 03:38:55 +00:00
|
|
|
bool LanguageModelDictContent::runGCInner(
|
|
|
|
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
|
|
|
|
const TrieMap::TrieMapRange trieMapRange,
|
|
|
|
const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
|
|
|
|
for (auto &entry : trieMapRange) {
|
|
|
|
const auto it = terminalIdMap->find(entry.key());
|
|
|
|
if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
|
|
|
|
// The word has been removed.
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
if (outNgramCount) {
|
|
|
|
*outNgramCount += 1;
|
|
|
|
}
|
|
|
|
if (entry.hasNextLevelMap()) {
|
|
|
|
if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
|
|
|
|
mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex),
|
|
|
|
outNgramCount)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2014-08-12 11:32:42 +00:00
|
|
|
int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
|
|
|
|
if (prevWordIds.empty()) {
|
|
|
|
return mTrieMap.getRootBitmapEntryIndex();
|
|
|
|
}
|
|
|
|
const int lastBitmapEntryIndex =
|
|
|
|
getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
|
|
|
|
if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
|
|
|
|
return TrieMap::INVALID_INDEX;
|
|
|
|
}
|
2014-09-17 12:16:31 +00:00
|
|
|
const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID);
|
2014-09-16 09:10:56 +00:00
|
|
|
const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
|
|
|
|
if (!result.mIsValid) {
|
|
|
|
if (!mTrieMap.put(oldestPrevWordId,
|
|
|
|
ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) {
|
|
|
|
return TrieMap::INVALID_INDEX;
|
|
|
|
}
|
|
|
|
}
|
2014-09-17 12:16:31 +00:00
|
|
|
return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID),
|
2014-08-12 11:32:42 +00:00
|
|
|
lastBitmapEntryIndex);
|
|
|
|
}
|
|
|
|
|
2014-08-05 05:51:11 +00:00
|
|
|
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
|
|
|
|
int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
|
|
|
|
for (const int wordId : prevWordIds) {
|
|
|
|
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndex);
|
|
|
|
if (!result.mIsValid) {
|
|
|
|
return TrieMap::INVALID_INDEX;
|
|
|
|
}
|
|
|
|
bitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
|
|
|
|
}
|
|
|
|
return bitmapEntryIndex;
|
|
|
|
}
|
|
|
|
|
2014-08-22 11:07:54 +00:00
|
|
|
bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
|
|
|
|
const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
|
|
|
|
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
|
|
|
|
if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
|
|
|
|
AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
|
|
|
|
level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
const ProbabilityEntry probabilityEntry =
|
|
|
|
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
|
|
|
|
if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
|
|
|
|
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
|
|
|
|
probabilityEntry.getHistoricalInfo(), headerPolicy);
|
|
|
|
if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
|
|
|
|
// Update the entry.
|
|
|
|
const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
|
|
|
|
if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
|
|
|
|
bitmapEntryIndex)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// Remove the entry.
|
|
|
|
if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (!probabilityEntry.representsBeginningOfSentence()) {
|
|
|
|
outEntryCounts[level] += 1;
|
|
|
|
}
|
|
|
|
if (!entry.hasNextLevelMap()) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
|
|
|
|
headerPolicy, outEntryCounts)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2014-08-21 03:48:24 +00:00
|
|
|
bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
|
2014-08-27 11:04:39 +00:00
|
|
|
const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel,
|
|
|
|
int *const outEntryCount) {
|
2014-08-21 03:48:24 +00:00
|
|
|
std::vector<int> prevWordIds;
|
|
|
|
std::vector<EntryInfoToTurncate> entryInfoVector;
|
|
|
|
if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
|
|
|
|
&prevWordIds, &entryInfoVector)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
|
2014-08-27 11:04:39 +00:00
|
|
|
*outEntryCount = static_cast<int>(entryInfoVector.size());
|
2014-08-21 03:48:24 +00:00
|
|
|
return true;
|
|
|
|
}
|
2014-08-27 11:04:39 +00:00
|
|
|
*outEntryCount = maxEntryCount;
|
2014-08-21 03:48:24 +00:00
|
|
|
const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
|
|
|
|
std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
|
|
|
|
entryInfoVector.end(),
|
|
|
|
EntryInfoToTurncate::Comparator());
|
|
|
|
for (int i = 0; i < entryCountToRemove; ++i) {
|
|
|
|
const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
|
|
|
|
if (!removeNgramProbabilityEntry(
|
|
|
|
WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
|
|
|
|
const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
|
|
|
|
std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
|
|
|
|
const int currentLevel = prevWordIds->size();
|
|
|
|
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
|
|
|
|
if (currentLevel < targetLevel) {
|
|
|
|
if (!entry.hasNextLevelMap()) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
prevWordIds->push_back(entry.key());
|
|
|
|
if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(),
|
|
|
|
prevWordIds, outEntryInfo)) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
prevWordIds->pop_back();
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
const ProbabilityEntry probabilityEntry =
|
|
|
|
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
|
|
|
|
const int probability = (mHasHistoricalInfo) ?
|
|
|
|
ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
|
|
|
|
headerPolicy) : probabilityEntry.getProbability();
|
|
|
|
outEntryInfo->emplace_back(probability,
|
|
|
|
probabilityEntry.getHistoricalInfo()->getTimeStamp(),
|
|
|
|
entry.key(), targetLevel, prevWordIds->data());
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
|
|
|
|
const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
|
|
|
|
if (left.mProbability != right.mProbability) {
|
|
|
|
return left.mProbability < right.mProbability;
|
|
|
|
}
|
|
|
|
if (left.mTimestamp != right.mTimestamp) {
|
|
|
|
return left.mTimestamp > right.mTimestamp;
|
|
|
|
}
|
|
|
|
if (left.mKey != right.mKey) {
|
|
|
|
return left.mKey < right.mKey;
|
|
|
|
}
|
|
|
|
if (left.mEntryLevel != right.mEntryLevel) {
|
|
|
|
return left.mEntryLevel > right.mEntryLevel;
|
|
|
|
}
|
|
|
|
for (int i = 0; i < left.mEntryLevel; ++i) {
|
|
|
|
if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
|
|
|
|
return left.mPrevWordIds[i] < right.mPrevWordIds[i];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// left and rigth represent the same entry.
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Records one eviction candidate. It stores the candidate's probability, timestamp, key and
// n-gram level, and copies the |entryLevel| previous word ids that lead to it from
// |prevWordIds|.
LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
        const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
        : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
    // For level-0 entries, callers pass std::vector::data() of an empty vector, which may be
    // nullptr. Passing a null pointer to memmove()/memcpy() is undefined behavior even with a
    // zero length. std::copy over the empty range [nullptr, nullptr + 0) is well defined.
    std::copy(prevWordIds, prevWordIds + mEntryLevel, mPrevWordIds);
}
|
|
|
|
|
2014-08-01 11:19:16 +00:00
|
|
|
} // namespace latinime
|