Support unigram historical information migration.

Bug: 13406708
Change-Id: Ibed15b3bc5d5ae68faefa379028dbe10d32b0c0f
main
Keisuke Kuroyanagi 2014-05-12 19:21:06 +09:00
parent 6b74f516dc
commit 9d7e8c717f
7 changed files with 110 additions and 27 deletions

View File

@ -335,8 +335,9 @@ static void latinime_BinaryDictionary_addUnigramWord(JNIEnv *env, jclass clazz,
if (!shortcutTargetCodePoints.empty()) { if (!shortcutTargetCodePoints.empty()) {
shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability);
} }
// Use 1 for count to indicate the word has been input.
const UnigramProperty unigramProperty(isNotAWord, isBlacklisted, const UnigramProperty unigramProperty(isNotAWord, isBlacklisted,
probability, timestamp, 0 /* level */, 0 /* count */, &shortcuts); probability, timestamp, 0 /* level */, 1 /* count */, &shortcuts);
dictionary->addUnigramWord(codePoints, codePointCount, &unigramProperty); dictionary->addUnigramWord(codePoints, codePointCount, &unigramProperty);
} }
@ -436,8 +437,9 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j
env->GetIntField(languageModelParam, shortcutProbabilityFieldId); env->GetIntField(languageModelParam, shortcutProbabilityFieldId);
shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability);
} }
// Use 1 for count to indicate the word has been input.
const UnigramProperty unigramProperty(isNotAWord, isBlacklisted, const UnigramProperty unigramProperty(isNotAWord, isBlacklisted,
unigramProbability, timestamp, 0 /* level */, 0 /* count */, &shortcuts); unigramProbability, timestamp, 0 /* level */, 1 /* count */, &shortcuts);
dictionary->addUnigramWord(word1CodePoints, word1Length, &unigramProperty); dictionary->addUnigramWord(word1CodePoints, word1Length, &unigramProperty);
if (word0) { if (word0) {
jint bigramProbability = env->GetIntField(languageModelParam, bigramProbabilityFieldId); jint bigramProbability = env->GetIntField(languageModelParam, bigramProbabilityFieldId);

View File

@ -257,10 +257,12 @@ const BigramEntry Ver4BigramListPolicy::createUpdatedBigramEntryFrom(
const int timestamp) const { const int timestamp) const {
// TODO: Consolidate historical info and probability. // TODO: Consolidate historical info and probability.
if (mHeaderPolicy->hasHistoricalInfoOfWords()) { if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
// Use 1 for count to indicate the bigram has been input.
const HistoricalInfo historicalInfoForUpdate(timestamp, 0 /* level */, 1 /* count */);
const HistoricalInfo updatedHistoricalInfo = const HistoricalInfo updatedHistoricalInfo =
ForgettingCurveUtils::createUpdatedHistoricalInfo( ForgettingCurveUtils::createUpdatedHistoricalInfo(
originalBigramEntry->getHistoricalInfo(), newProbability, timestamp, originalBigramEntry->getHistoricalInfo(), newProbability,
mHeaderPolicy); &historicalInfoForUpdate, mHeaderPolicy);
return originalBigramEntry->updateHistoricalInfoAndGetEntry(&updatedHistoricalInfo); return originalBigramEntry->updateHistoricalInfoAndGetEntry(&updatedHistoricalInfo);
} else { } else {
return originalBigramEntry->updateProbabilityAndGetEntry(newProbability); return originalBigramEntry->updateProbabilityAndGetEntry(newProbability);

View File

@ -387,11 +387,12 @@ const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
const UnigramProperty *const unigramProperty) const { const UnigramProperty *const unigramProperty) const {
// TODO: Consolidate historical info and probability. // TODO: Consolidate historical info and probability.
if (mHeaderPolicy->hasHistoricalInfoOfWords()) { if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(),
unigramProperty->getLevel(), unigramProperty->getCount());
const HistoricalInfo updatedHistoricalInfo = const HistoricalInfo updatedHistoricalInfo =
ForgettingCurveUtils::createUpdatedHistoricalInfo( ForgettingCurveUtils::createUpdatedHistoricalInfo(
originalProbabilityEntry->getHistoricalInfo(), originalProbabilityEntry->getHistoricalInfo(),
unigramProperty->getProbability(), unigramProperty->getTimestamp(), unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy);
mHeaderPolicy);
return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo( return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo(
&updatedHistoricalInfo); &updatedHistoricalInfo);
} else { } else {

View File

@ -425,6 +425,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code
} }
int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) { int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) {
// TODO: Return code point count like other methods.
// Null termination.
outCodePoints[0] = 0;
if (token == 0) { if (token == 0) {
mTerminalPtNodePositionsForIteratingWords.clear(); mTerminalPtNodePositionsForIteratingWords.clear();
DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy( DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy(
@ -441,8 +444,13 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
} }
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
int unigramProbability = NOT_A_PROBABILITY; int unigramProbability = NOT_A_PROBABILITY;
getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH, const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
outCodePoints, &unigramProbability); terminalPtNodePos, MAX_WORD_LENGTH, outCodePoints, &unigramProbability);
if (codePointCount < MAX_WORD_LENGTH) {
// Null termination. outCodePoints has to be null-terminated or contain MAX_WORD_LENGTH
// code points.
outCodePoints[codePointCount] = 0;
}
const int nextToken = token + 1; const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) { if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated. // All words have been iterated.

View File

@ -30,7 +30,7 @@ const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8;
const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60; const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60;
const int ForgettingCurveUtils::MAX_LEVEL = 3; const int ForgettingCurveUtils::MAX_LEVEL = 3;
const int ForgettingCurveUtils::MIN_VALID_LEVEL = 1; const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1;
const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15; const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15;
const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14; const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14;
@ -41,25 +41,34 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
// TODO: Revise the logic to decide the initial probability depending on the given probability. // TODO: Revise the logic to decide the initial probability depending on the given probability.
/* static */ const HistoricalInfo ForgettingCurveUtils::createUpdatedHistoricalInfo( /* static */ const HistoricalInfo ForgettingCurveUtils::createUpdatedHistoricalInfo(
const HistoricalInfo *const originalHistoricalInfo, const HistoricalInfo *const originalHistoricalInfo, const int newProbability,
const int newProbability, const int timestamp, const HeaderPolicy *const headerPolicy) { const HistoricalInfo *const newHistoricalInfo, const HeaderPolicy *const headerPolicy) {
const int timestamp = newHistoricalInfo->getTimeStamp();
if (newProbability != NOT_A_PROBABILITY && originalHistoricalInfo->getLevel() == 0) { if (newProbability != NOT_A_PROBABILITY && originalHistoricalInfo->getLevel() == 0) {
return HistoricalInfo(timestamp, MIN_VALID_LEVEL /* level */, 0 /* count */); // Add entry as a valid word.
} else if (!originalHistoricalInfo->isValid()) { const int level = clampToVisibleEntryLevelRange(newHistoricalInfo->getLevel());
const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
return HistoricalInfo(timestamp, level, count);
} else if (!originalHistoricalInfo->isValid()
|| originalHistoricalInfo->getLevel() < newHistoricalInfo->getLevel()
|| (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel()
&& originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) {
// Initial information. // Initial information.
return HistoricalInfo(timestamp, 0 /* level */, 1 /* count */); const int level = clampToValidLevelRange(newHistoricalInfo->getLevel());
const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
return HistoricalInfo(timestamp, level, count);
} else { } else {
const int updatedCount = originalHistoricalInfo->getCount() + 1; const int updatedCount = originalHistoricalInfo->getCount() + 1;
if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) { if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) {
// The count exceeds the max value the level can be incremented. // The count exceeds the max value the level can be incremented.
if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) { if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) {
// The level is already max. // The level is already max.
return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), return HistoricalInfo(timestamp,
originalHistoricalInfo->getCount()); originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount());
} else { } else {
// Level up. // Level up.
return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel() + 1, return HistoricalInfo(timestamp,
0 /* count */); originalHistoricalInfo->getLevel() + 1, 0 /* count */);
} }
} else { } else {
return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), updatedCount); return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), updatedCount);
@ -73,8 +82,8 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
headerPolicy->getForgettingCurveDurationToLevelDown()); headerPolicy->getForgettingCurveDurationToLevelDown());
return sProbabilityTable.getProbability( return sProbabilityTable.getProbability(
headerPolicy->getForgettingCurveProbabilityValuesTableId(), headerPolicy->getForgettingCurveProbabilityValuesTableId(),
std::min(std::max(historicalInfo->getLevel(), 0), MAX_LEVEL), clampToValidLevelRange(historicalInfo->getLevel()),
std::min(std::max(elapsedTimeStepCount, 0), MAX_ELAPSED_TIME_STEP_COUNT)); clampToValidTimeStepCountRange(elapsedTimeStepCount));
} }
/* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability, /* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability,
@ -155,6 +164,23 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
return elapsedTimeInSeconds / timeStepDurationInSeconds; return elapsedTimeInSeconds / timeStepDurationInSeconds;
} }
// Restricts the given level to the range [MIN_VISIBLE_LEVEL, MAX_LEVEL].
/* static */ int ForgettingCurveUtils::clampToVisibleEntryLevelRange(const int level) {
    // Apply the lower bound first, then the upper bound.
    const int levelAtLeastVisible = std::max(level, MIN_VISIBLE_LEVEL);
    return std::min(levelAtLeastVisible, MAX_LEVEL);
}
// Restricts the given count to the range [0, N - 1], where N is the value returned by
// headerPolicy->getForgettingCurveOccurrencesToLevelUp().
/* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count,
        const HeaderPolicy *const headerPolicy) {
    const int largestCount = headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1;
    // Negative counts are raised to 0 before the upper bound is applied.
    const int nonNegativeCount = std::max(count, 0);
    return std::min(nonNegativeCount, largestCount);
}
// Restricts the given level to the range [0, MAX_LEVEL].
/* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) {
    // Negative levels are raised to 0 before the upper bound is applied.
    const int nonNegativeLevel = std::max(level, 0);
    return std::min(nonNegativeLevel, MAX_LEVEL);
}
// Restricts the given time step count to the range [0, MAX_ELAPSED_TIME_STEP_COUNT].
/* static */ int ForgettingCurveUtils::clampToValidTimeStepCountRange(const int timeStepCount) {
    // Negative step counts are raised to 0 before the upper bound is applied.
    const int nonNegativeStepCount = std::max(timeStepCount, 0);
    return std::min(nonNegativeStepCount, MAX_ELAPSED_TIME_STEP_COUNT);
}
const int ForgettingCurveUtils::ProbabilityTable::PROBABILITY_TABLE_COUNT = 4; const int ForgettingCurveUtils::ProbabilityTable::PROBABILITY_TABLE_COUNT = 4;
const int ForgettingCurveUtils::ProbabilityTable::WEAK_PROBABILITY_TABLE_ID = 0; const int ForgettingCurveUtils::ProbabilityTable::WEAK_PROBABILITY_TABLE_ID = 0;
const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID = 1; const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID = 1;

View File

@ -30,7 +30,7 @@ class ForgettingCurveUtils {
public: public:
static const HistoricalInfo createUpdatedHistoricalInfo( static const HistoricalInfo createUpdatedHistoricalInfo(
const HistoricalInfo *const originalHistoricalInfo, const int newProbability, const HistoricalInfo *const originalHistoricalInfo, const int newProbability,
const int timestamp, const HeaderPolicy *const headerPolicy); const HistoricalInfo *const newHistoricalInfo, const HeaderPolicy *const headerPolicy);
static const HistoricalInfo createHistoricalInfoToSave( static const HistoricalInfo createHistoricalInfoToSave(
const HistoricalInfo *const originalHistoricalInfo, const HistoricalInfo *const originalHistoricalInfo,
@ -93,7 +93,7 @@ class ForgettingCurveUtils {
static const int DECAY_INTERVAL_SECONDS; static const int DECAY_INTERVAL_SECONDS;
static const int MAX_LEVEL; static const int MAX_LEVEL;
static const int MIN_VALID_LEVEL; static const int MIN_VISIBLE_LEVEL;
static const int MAX_ELAPSED_TIME_STEP_COUNT; static const int MAX_ELAPSED_TIME_STEP_COUNT;
static const int DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD; static const int DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
@ -103,8 +103,11 @@ class ForgettingCurveUtils {
static const ProbabilityTable sProbabilityTable; static const ProbabilityTable sProbabilityTable;
static int backoff(const int unigramProbability); static int backoff(const int unigramProbability);
static int getElapsedTimeStepCount(const int timestamp, const int durationToLevelDown); static int getElapsedTimeStepCount(const int timestamp, const int durationToLevelDown);
static int clampToVisibleEntryLevelRange(const int level);
static int clampToValidLevelRange(const int level);
static int clampToValidCountRange(const int count, const HeaderPolicy *const headerPolicy);
static int clampToValidTimeStepCountRange(const int timeStepCount);
}; };
} // namespace latinime } // namespace latinime
#endif /* LATINIME_FORGETTING_CURVE_UTILS_H */ #endif /* LATINIME_FORGETTING_CURVE_UTILS_H */

View File

@ -93,15 +93,17 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
private File createEmptyDictionaryAndGetFile(final String dictId, private File createEmptyDictionaryAndGetFile(final String dictId,
final int formatVersion) throws IOException { final int formatVersion) throws IOException {
if (formatVersion == FormatSpec.VERSION4) { if (formatVersion == FormatSpec.VERSION4
return createEmptyVer4DictionaryAndGetFile(dictId); || formatVersion == FormatSpec.VERSION4_ONLY_FOR_TESTING) {
return createEmptyVer4DictionaryAndGetFile(dictId, formatVersion);
} else { } else {
throw new IOException("Dictionary format version " + formatVersion throw new IOException("Dictionary format version " + formatVersion
+ " is not supported."); + " is not supported.");
} }
} }
private File createEmptyVer4DictionaryAndGetFile(final String dictId) throws IOException { private File createEmptyVer4DictionaryAndGetFile(final String dictId, final int formatVersion)
throws IOException {
final File file = File.createTempFile(dictId, TEST_DICT_FILE_EXTENSION, final File file = File.createTempFile(dictId, TEST_DICT_FILE_EXTENSION,
getContext().getCacheDir()); getContext().getCacheDir());
FileUtils.deleteRecursively(file); FileUtils.deleteRecursively(file);
@ -113,7 +115,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
DictionaryHeader.ATTRIBUTE_VALUE_TRUE); DictionaryHeader.ATTRIBUTE_VALUE_TRUE);
attributeMap.put(DictionaryHeader.HAS_HISTORICAL_INFO_KEY, attributeMap.put(DictionaryHeader.HAS_HISTORICAL_INFO_KEY,
DictionaryHeader.ATTRIBUTE_VALUE_TRUE); DictionaryHeader.ATTRIBUTE_VALUE_TRUE);
if (BinaryDictionaryUtils.createEmptyDictFile(file.getAbsolutePath(), FormatSpec.VERSION4, if (BinaryDictionaryUtils.createEmptyDictFile(file.getAbsolutePath(), formatVersion,
LocaleUtils.constructLocaleFromString(TEST_LOCALE), attributeMap)) { LocaleUtils.constructLocaleFromString(TEST_LOCALE), attributeMap)) {
return file; return file;
} else { } else {
@ -562,4 +564,43 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
} }
} }
} }
/**
 * Runs the migration test from {@code FormatSpec.VERSION4_ONLY_FOR_TESTING} to
 * {@code FormatSpec.VERSION4}.
 */
public void testDictMigration() {
    final int sourceFormatVersion = FormatSpec.VERSION4_ONLY_FOR_TESTING;
    final int targetFormatVersion = FormatSpec.VERSION4;
    testDictMigration(sourceFormatVersion, targetFormatVersion);
}
/**
 * Creates an empty dictionary in {@code fromFormatVersion}, adds unigrams to it, migrates it to
 * {@code toFormatVersion} and checks that word validity and relative frequencies are as
 * expected after the migration.
 *
 * The dictionary is closed and its file deleted in a finally block so that a failing assertion
 * does not leave the dictionary open or the temporary file behind.
 *
 * @param fromFormatVersion format version used to create the initial dictionary.
 * @param toFormatVersion format version the dictionary is migrated to.
 */
private void testDictMigration(final int fromFormatVersion, final int toFormatVersion) {
    setCurrentTimeForTestMode(mCurrentTime);
    File dictFile = null;
    try {
        dictFile = createEmptyDictionaryAndGetFile("TestBinaryDictionary", fromFormatVersion);
    } catch (IOException e) {
        fail("IOException while writing an initial dictionary : " + e);
    }
    final BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(),
            0 /* offset */, dictFile.length(), true /* useFullEditDistance */,
            Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */);
    try {
        // TODO: Add tests for bigrams when the implementation gets ready.
        addUnigramWord(binaryDictionary, "aaa", DUMMY_PROBABILITY);
        assertTrue(binaryDictionary.isValidWord("aaa"));
        addUnigramWord(binaryDictionary, "bbb", Dictionary.NOT_A_PROBABILITY);
        assertFalse(binaryDictionary.isValidWord("bbb"));
        // "ccc" is added five times, "aaa" only once; the frequency comparison below relies
        // on this difference.
        for (int i = 0; i < 5; i++) {
            addUnigramWord(binaryDictionary, "ccc", DUMMY_PROBABILITY);
        }

        assertEquals(fromFormatVersion, binaryDictionary.getFormatVersion());
        assertTrue(binaryDictionary.migrateTo(toFormatVersion));
        assertTrue(binaryDictionary.isValidDictionary());
        assertEquals(toFormatVersion, binaryDictionary.getFormatVersion());

        // Validity and relative frequency are expected to be the same after the migration.
        assertTrue(binaryDictionary.isValidWord("aaa"));
        assertFalse(binaryDictionary.isValidWord("bbb"));
        assertTrue(binaryDictionary.getFrequency("aaa") < binaryDictionary.getFrequency("ccc"));
        // Adding "bbb" once more after the migration is expected to make it a valid word.
        addUnigramWord(binaryDictionary, "bbb", Dictionary.NOT_A_PROBABILITY);
        assertTrue(binaryDictionary.isValidWord("bbb"));
    } finally {
        // Always release the dictionary and remove the test file, even on assertion failure.
        binaryDictionary.close();
        dictFile.delete();
    }
}
} }