am 962d2c08: Discard useless unigrams when overflowing.

* commit '962d2c083aa137bfa6963c7da9cab018c81cb9ed':
  Discard useless unigrams when overflowing.
main
Keisuke Kuroyanagi 2013-12-10 01:48:31 -08:00 committed by Android Git Automerger
commit c7f1074773
4 changed files with 160 additions and 4 deletions

View File

@ -121,8 +121,14 @@ bool Ver4PatriciaTrieNodeWriter::markPtNodeAsWillBecomeNonTerminal(
const PatriciaTrieReadingUtils::NodeFlags updatedFlags = const PatriciaTrieReadingUtils::NodeFlags updatedFlags =
DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */, DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */,
false /* isDeleted */, true /* willBecomeNonTerminal */); false /* isDeleted */, true /* willBecomeNonTerminal */);
int writingPos = toBeUpdatedPtNodeParams->getHeadPos(); if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition(
toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */)) {
AKLOGE("Cannot update terminal position lookup table. terminal id: %d",
toBeUpdatedPtNodeParams->getTerminalId());
return false;
}
// Update flags. // Update flags.
int writingPos = toBeUpdatedPtNodeParams->getHeadPos();
return DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags, return DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags,
&writingPos); &writingPos);
} }

View File

@ -17,6 +17,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h"
#include <cstring> #include <cstring>
#include <queue>
#include "suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.h"
#include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/header/header_policy.h"
@ -97,10 +98,16 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) { &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) {
return false; return false;
} }
const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted
.getValidUnigramCount();
if (headerPolicy->isDecayingDict() if (headerPolicy->isDecayingDict()
&& traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted && unigramCount > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) {
.getValidUnigramCount() > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) { if (!turncateUnigrams(&ptNodeReader, &ptNodeWriter,
// TODO: Remove more unigrams. ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC)) {
AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount,
ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC);
return false;
}
} }
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
@ -179,6 +186,42 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return true; return true;
} }
bool Ver4PatriciaTrieWritingHelper::turncateUnigrams(
const Ver4PatriciaTrieNodeReader *const ptNodeReader,
Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
const TerminalPositionLookupTable *const terminalPosLookupTable =
mBuffers->getTerminalPositionLookupTable();
const int nextTerminalId = terminalPosLookupTable->getNextTerminalId();
std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator>
priorityQueue;
for (int i = 0; i < nextTerminalId; ++i) {
const int terminalPos = terminalPosLookupTable->getTerminalPtNodePosition(i);
if (terminalPos == NOT_A_DICT_POS) {
continue;
}
const ProbabilityEntry probabilityEntry =
mBuffers->getProbabilityDictContent()->getProbabilityEntry(i);
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo()) :
probabilityEntry.getProbability();
priorityQueue.push(DictProbability(terminalPos, probability,
probabilityEntry.getHistoricalInfo()->getTimeStamp()));
}
// Delete unigrams.
while (static_cast<int>(priorityQueue.size()) > maxUnigramCount) {
const int ptNodePos = priorityQueue.top().getDictPos();
const PtNodeParams ptNodeParams =
ptNodeReader->fetchNodeInfoInBufferFromPtNodePos(ptNodePos);
if (!ptNodeWriter->markPtNodeAsWillBecomeNonTerminal(&ptNodeParams)) {
AKLOGE("Cannot mark PtNode as willBecomeNonterminal. PtNode pos: %d", ptNodePos);
return false;
}
priorityQueue.pop();
}
return true;
}
bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds
::onVisitingPtNode(const PtNodeParams *const ptNodeParams) { ::onVisitingPtNode(const PtNodeParams *const ptNodeParams) {
if (!ptNodeParams->isTerminal()) { if (!ptNodeParams->isTerminal()) {

View File

@ -25,6 +25,7 @@ namespace latinime {
class HeaderPolicy; class HeaderPolicy;
class Ver4DictBuffers; class Ver4DictBuffers;
class Ver4PatriciaTrieNodeReader;
class Ver4PatriciaTrieNodeWriter; class Ver4PatriciaTrieNodeWriter;
class Ver4PatriciaTrieWritingHelper { class Ver4PatriciaTrieWritingHelper {
@ -64,10 +65,56 @@ class Ver4PatriciaTrieWritingHelper {
const TerminalPositionLookupTable::TerminalIdMap *const mTerminalIdMap; const TerminalPositionLookupTable::TerminalIdMap *const mTerminalIdMap;
}; };
// For truncateUnigrams().
class DictProbability {
public:
DictProbability(const int dictPos, const int probability, const int timestamp)
: mDictPos(dictPos), mProbability(probability), mTimestamp(timestamp) {}
int getDictPos() const {
return mDictPos;
}
int getProbability() const {
return mProbability;
}
int getTimestamp() const {
return mTimestamp;
}
private:
DISALLOW_DEFAULT_CONSTRUCTOR(DictProbability);
int mDictPos;
int mProbability;
int mTimestamp;
};
// For truncateUnigrams().
class DictProbabilityComparator {
public:
bool operator()(const DictProbability &left, const DictProbability &right) {
if (left.getProbability() != right.getProbability()) {
return left.getProbability() > right.getProbability();
}
if (left.getTimestamp() != right.getTimestamp()) {
return left.getTimestamp() < right.getTimestamp();
}
return left.getDictPos() > right.getDictPos();
}
private:
DISALLOW_ASSIGNMENT_OPERATOR(DictProbabilityComparator);
};
bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy,
Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount, Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount,
int *const outBigramCount); int *const outBigramCount);
bool turncateUnigrams(const Ver4PatriciaTrieNodeReader *const ptNodeReader,
Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount);
Ver4DictBuffers *const mBuffers; Ver4DictBuffers *const mBuffers;
}; };
} // namespace latinime } // namespace latinime

View File

@ -303,6 +303,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(), BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(),
0 /* offset */, dictFile.length(), true /* useFullEditDistance */, 0 /* offset */, dictFile.length(), true /* useFullEditDistance */,
Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */); Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */);
setCurrentTime(binaryDictionary, mCurrentTime);
final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random); final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random);
final ArrayList<String> words = new ArrayList<String>(); final ArrayList<String> words = new ArrayList<String>();
@ -339,7 +340,65 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
forcePassingLongTime(binaryDictionary); forcePassingLongTime(binaryDictionary);
assertEquals(0, Integer.parseInt(binaryDictionary.getPropertyForTests( assertEquals(0, Integer.parseInt(binaryDictionary.getPropertyForTests(
BinaryDictionary.UNIGRAM_COUNT_QUERY))); BinaryDictionary.UNIGRAM_COUNT_QUERY)));
}
public void testOverflowUnigrams() {
testOverflowUnigrams(FormatSpec.VERSION4);
}
private void testOverflowUnigrams(final int formatVersion) {
final int unigramCount = 20000;
final int eachUnigramTypedCount = 5;
final int strongUnigramTypedCount = 20;
final int weakUnigramTypedCount = 1;
final int codePointSetSize = 50;
final long seed = System.currentTimeMillis();
final Random random = new Random(seed);
File dictFile = null;
try {
dictFile = createEmptyDictionaryAndGetFile("TestBinaryDictionary", formatVersion);
} catch (IOException e) {
fail("IOException while writing an initial dictionary : " + e);
}
BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(),
0 /* offset */, dictFile.length(), true /* useFullEditDistance */,
Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */);
setCurrentTime(binaryDictionary, mCurrentTime);
final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random);
final String strong = "strong";
final String weak = "weak";
for (int j = 0; j < strongUnigramTypedCount; j++) {
addUnigramWord(binaryDictionary, strong, DUMMY_PROBABILITY);
}
for (int j = 0; j < weakUnigramTypedCount; j++) {
addUnigramWord(binaryDictionary, weak, DUMMY_PROBABILITY);
}
assertTrue(binaryDictionary.isValidWord(strong));
assertTrue(binaryDictionary.isValidWord(weak));
for (int i = 0; i < unigramCount; i++) {
final String word = CodePointUtils.generateWord(random, codePointSet);
for (int j = 0; j < eachUnigramTypedCount; j++) {
addUnigramWord(binaryDictionary, word, DUMMY_PROBABILITY);
}
if (binaryDictionary.needsToRunGC(true /* mindsBlockByGC */)) {
final int unigramCountBeforeGC =
Integer.parseInt(binaryDictionary.getPropertyForTests(
BinaryDictionary.UNIGRAM_COUNT_QUERY));
assertTrue(binaryDictionary.isValidWord(strong));
assertTrue(binaryDictionary.isValidWord(weak));
binaryDictionary.flushWithGC();
final int unigramCountAfterGC =
Integer.parseInt(binaryDictionary.getPropertyForTests(
BinaryDictionary.UNIGRAM_COUNT_QUERY));
assertTrue(unigramCountBeforeGC > unigramCountAfterGC);
assertFalse(binaryDictionary.isValidWord(weak));
assertTrue(binaryDictionary.isValidWord(strong));
break;
}
}
} }
public void testAddManyBigramsToDecayingDict() { public void testAddManyBigramsToDecayingDict() {
@ -363,6 +422,7 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(), BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(),
0 /* offset */, dictFile.length(), true /* useFullEditDistance */, 0 /* offset */, dictFile.length(), true /* useFullEditDistance */,
Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */); Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */);
setCurrentTime(binaryDictionary, mCurrentTime);
final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random); final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random);
final ArrayList<String> words = new ArrayList<String>(); final ArrayList<String> words = new ArrayList<String>();