am 21ce9c4a: Merge "Stochastic decay."

* commit '21ce9c4a89f90593e54ae29670ebd09a14533665':
  Stochastic decay.
This commit is contained in:
Keisuke Kuroyanagi 2013-10-02 06:57:13 -07:00 committed by Android Git Automerger
commit 879ae3aa92
6 changed files with 77 additions and 95 deletions

View file

@ -43,7 +43,7 @@ void DynamicBigramListPolicy::getNextBigram(int *const outBigramPos, int *const
} }
*outProbability = BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags); *outProbability = BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags);
*outHasNext = BigramListReadWriteUtils::hasNext(bigramFlags); *outHasNext = BigramListReadWriteUtils::hasNext(bigramFlags);
if (mIsDecayingDict && !ForgettingCurveUtils::isValidBigram(*outProbability)) { if (mIsDecayingDict && !ForgettingCurveUtils::isValidEncodedProbability(*outProbability)) {
// This bigram is too weak to output. // This bigram is too weak to output.
*outBigramPos = NOT_A_DICT_POS; *outBigramPos = NOT_A_DICT_POS;
} else { } else {
@ -261,8 +261,8 @@ bool DynamicBigramListPolicy::addNewBigramEntryToBigramList(const int bigramTarg
const int originalProbability = BigramListReadWriteUtils::getProbabilityFromFlags( const int originalProbability = BigramListReadWriteUtils::getProbabilityFromFlags(
bigramFlags); bigramFlags);
const int probabilityToWrite = mIsDecayingDict ? const int probabilityToWrite = mIsDecayingDict ?
ForgettingCurveUtils::getUpdatedBigramProbabilityDelta( ForgettingCurveUtils::getUpdatedEncodedProbability(originalProbability,
originalProbability, probability) : probability; probability) : probability;
const BigramListReadWriteUtils::BigramFlags updatedFlags = const BigramListReadWriteUtils::BigramFlags updatedFlags =
BigramListReadWriteUtils::setProbabilityInFlags(bigramFlags, BigramListReadWriteUtils::setProbabilityInFlags(bigramFlags,
probabilityToWrite); probabilityToWrite);
@ -294,7 +294,7 @@ bool DynamicBigramListPolicy::writeNewBigramEntry(const int bigramTargetPos, con
int *const writingPos) { int *const writingPos) {
// hasNext is false because we are adding a new bigram entry at the end of the bigram list. // hasNext is false because we are adding a new bigram entry at the end of the bigram list.
const int probabilityToWrite = mIsDecayingDict ? const int probabilityToWrite = mIsDecayingDict ?
ForgettingCurveUtils::getUpdatedBigramProbabilityDelta(NOT_A_PROBABILITY, probability) : ForgettingCurveUtils::getUpdatedEncodedProbability(NOT_A_PROBABILITY, probability) :
probability; probability;
return BigramListReadWriteUtils::createAndWriteBigramEntry(mBuffer, bigramTargetPos, return BigramListReadWriteUtils::createAndWriteBigramEntry(mBuffer, bigramTargetPos,
probabilityToWrite, false /* hasNext */, writingPos); probabilityToWrite, false /* hasNext */, writingPos);
@ -365,9 +365,9 @@ bool DynamicBigramListPolicy::updateProbabilityForDecay(
*outRemoved = false; *outRemoved = false;
if (mIsDecayingDict) { if (mIsDecayingDict) {
// Update bigram probability for decaying. // Update bigram probability for decaying.
const int newProbability = ForgettingCurveUtils::getBigramProbabilityDeltaToSave( const int newProbability = ForgettingCurveUtils::getEncodedProbabilityToSave(
BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags)); BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags));
if (ForgettingCurveUtils::isValidBigram(newProbability)) { if (ForgettingCurveUtils::isValidEncodedProbability(newProbability)) {
// Write new probability. // Write new probability.
const BigramListReadWriteUtils::BigramFlags updatedBigramFlags = const BigramListReadWriteUtils::BigramFlags updatedBigramFlags =
BigramListReadWriteUtils::setProbabilityInFlags( BigramListReadWriteUtils::setProbabilityInFlags(

View file

@ -29,14 +29,14 @@ bool DynamicPatriciaTrieGcEventListeners
bool isUselessPtNode = !node->isTerminal(); bool isUselessPtNode = !node->isTerminal();
if (node->isTerminal() && mIsDecayingDict) { if (node->isTerminal() && mIsDecayingDict) {
const int newProbability = const int newProbability =
ForgettingCurveUtils::getUnigramProbabilityToSave(node->getProbability()); ForgettingCurveUtils::getEncodedProbabilityToSave(node->getProbability());
int writingPos = node->getProbabilityFieldPos(); int writingPos = node->getProbabilityFieldPos();
// Update probability. // Update probability.
if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition( if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition(
mBuffer, newProbability, &writingPos)) { mBuffer, newProbability, &writingPos)) {
return false; return false;
} }
if (!ForgettingCurveUtils::isValidUnigram(newProbability)) { if (!ForgettingCurveUtils::isValidEncodedProbability(newProbability)) {
isUselessPtNode = false; isUselessPtNode = false;
} }
} }

View file

@ -545,7 +545,7 @@ bool DynamicPatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
int DynamicPatriciaTrieWritingHelper::getUpdatedProbability(const int originalProbability, int DynamicPatriciaTrieWritingHelper::getUpdatedProbability(const int originalProbability,
const int newProbability) { const int newProbability) {
if (mNeedsToDecay) { if (mNeedsToDecay) {
return ForgettingCurveUtils::getUpdatedUnigramProbability(originalProbability, return ForgettingCurveUtils::getUpdatedEncodedProbability(originalProbability,
newProbability); newProbability);
} else { } else {
return newProbability; return newProbability;

View file

@ -14,6 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#include <stdlib.h>
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h"
@ -26,106 +28,91 @@ const int ForgettingCurveUtils::MAX_BIGRAM_COUNT = 12000;
const int ForgettingCurveUtils::MAX_BIGRAM_COUNT_AFTER_GC = 10000; const int ForgettingCurveUtils::MAX_BIGRAM_COUNT_AFTER_GC = 10000;
const int ForgettingCurveUtils::MAX_COMPUTED_PROBABILITY = 127; const int ForgettingCurveUtils::MAX_COMPUTED_PROBABILITY = 127;
const int ForgettingCurveUtils::MAX_UNIGRAM_PROBABILITY = 120; const int ForgettingCurveUtils::MAX_ENCODED_PROBABILITY = 15;
const int ForgettingCurveUtils::MIN_VALID_UNIGRAM_PROBABILITY = 24; const int ForgettingCurveUtils::MIN_VALID_ENCODED_PROBABILITY = 3;
const int ForgettingCurveUtils::UNIGRAM_PROBABILITY_STEP = 8; const int ForgettingCurveUtils::ENCODED_PROBABILITY_STEP = 1;
const int ForgettingCurveUtils::MAX_BIGRAM_PROBABILITY_DELTA = 15; // Currently, we try to decay each uni/bigram once every 2 hours. Accordingly, the expected
const int ForgettingCurveUtils::MIN_VALID_BIGRAM_PROBABILITY_DELTA = 3; // duration of the decay is approximately 66hours.
const int ForgettingCurveUtils::BIGRAM_PROBABILITY_DELTA_STEP = 1; const float ForgettingCurveUtils::MIN_PROBABILITY_TO_DECAY = 0.03f;
/* static */ int ForgettingCurveUtils::getProbability(const int encodedUnigramProbability, /* static */ int ForgettingCurveUtils::getProbability(const int encodedUnigramProbability,
const int encodedBigramProbabilityDelta) { const int encodedBigramProbability) {
if (encodedUnigramProbability == NOT_A_PROBABILITY) { if (encodedUnigramProbability == NOT_A_PROBABILITY) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} else if (encodedBigramProbabilityDelta == NOT_A_PROBABILITY) { } else if (encodedBigramProbability == NOT_A_PROBABILITY) {
const int rawProbability = ProbabilityUtils::backoff(decodeUnigramProbability( return backoff(decodeUnigramProbability(encodedUnigramProbability));
encodedUnigramProbability));
return min(getDecayedProbability(rawProbability), MAX_COMPUTED_PROBABILITY);
} else { } else {
const int rawProbability = ProbabilityUtils::computeProbabilityForBigram( const int unigramProbability = decodeUnigramProbability(encodedUnigramProbability);
decodeUnigramProbability(encodedUnigramProbability), const int bigramProbability = decodeBigramProbability(encodedBigramProbability);
decodeBigramProbabilityDelta(encodedBigramProbabilityDelta)); return min(max(unigramProbability, bigramProbability), MAX_COMPUTED_PROBABILITY);
return min(getDecayedProbability(rawProbability), MAX_COMPUTED_PROBABILITY);
} }
} }
/* static */ int ForgettingCurveUtils::getUpdatedUnigramProbability( // Caveat: Unlike getProbability(), this method doesn't assume special bigram probability encoding
// (i.e. unigram probability + bigram probability delta).
/* static */ int ForgettingCurveUtils::getUpdatedEncodedProbability(
const int originalEncodedProbability, const int newProbability) { const int originalEncodedProbability, const int newProbability) {
if (originalEncodedProbability == NOT_A_PROBABILITY) { if (originalEncodedProbability == NOT_A_PROBABILITY) {
// The unigram is not in this dictionary.
if (newProbability == NOT_A_PROBABILITY) {
// The unigram is not in other dictionaries.
return 0;
} else {
return MIN_VALID_UNIGRAM_PROBABILITY;
}
} else {
if (newProbability != NOT_A_PROBABILITY
&& originalEncodedProbability < MIN_VALID_UNIGRAM_PROBABILITY) {
return MIN_VALID_UNIGRAM_PROBABILITY;
}
return min(originalEncodedProbability + UNIGRAM_PROBABILITY_STEP, MAX_UNIGRAM_PROBABILITY);
}
}
/* static */ int ForgettingCurveUtils::getUnigramProbabilityToSave(const int encodedProbability) {
return max(encodedProbability - UNIGRAM_PROBABILITY_STEP, 0);
}
/* static */ int ForgettingCurveUtils::getBigramProbabilityDeltaToSave(
const int encodedProbabilityDelta) {
return max(encodedProbabilityDelta - BIGRAM_PROBABILITY_DELTA_STEP, 0);
}
/* static */ int ForgettingCurveUtils::getUpdatedBigramProbabilityDelta(
const int originalEncodedProbabilityDelta, const int newProbability) {
if (originalEncodedProbabilityDelta == NOT_A_PROBABILITY) {
// The bigram relation is not in this dictionary. // The bigram relation is not in this dictionary.
if (newProbability == NOT_A_PROBABILITY) { if (newProbability == NOT_A_PROBABILITY) {
// The bigram target is not in other dictionaries. // The bigram target is not in other dictionaries.
return 0; return 0;
} else { } else {
return MIN_VALID_BIGRAM_PROBABILITY_DELTA; return MIN_VALID_ENCODED_PROBABILITY;
} }
} else { } else {
if (newProbability != NOT_A_PROBABILITY if (newProbability != NOT_A_PROBABILITY
&& originalEncodedProbabilityDelta < MIN_VALID_BIGRAM_PROBABILITY_DELTA) { && originalEncodedProbability < MIN_VALID_ENCODED_PROBABILITY) {
return MIN_VALID_BIGRAM_PROBABILITY_DELTA; return MIN_VALID_ENCODED_PROBABILITY;
} }
return min(originalEncodedProbabilityDelta + BIGRAM_PROBABILITY_DELTA_STEP, return min(originalEncodedProbability + ENCODED_PROBABILITY_STEP, MAX_ENCODED_PROBABILITY);
MAX_BIGRAM_PROBABILITY_DELTA);
} }
} }
/* static */ int ForgettingCurveUtils::isValidUnigram(const int encodedUnigramProbability) { /* static */ int ForgettingCurveUtils::isValidEncodedProbability(const int encodedProbability) {
return encodedUnigramProbability >= MIN_VALID_UNIGRAM_PROBABILITY; return encodedProbability >= MIN_VALID_ENCODED_PROBABILITY;
} }
/* static */ int ForgettingCurveUtils::isValidBigram(const int encodedBigramProbabilityDelta) { /* static */ int ForgettingCurveUtils::getEncodedProbabilityToSave(const int encodedProbability) {
return encodedBigramProbabilityDelta >= MIN_VALID_BIGRAM_PROBABILITY_DELTA; const int currentEncodedProbability = max(min(encodedProbability, MAX_ENCODED_PROBABILITY), 0);
// TODO: Implement the decay in more proper way.
const float currentRate = static_cast<float>(currentEncodedProbability)
/ static_cast<float>(MAX_ENCODED_PROBABILITY);
const float thresholdToDecay = MIN_PROBABILITY_TO_DECAY
+ (1.0f - MIN_PROBABILITY_TO_DECAY) * (1.0f - currentRate);
const float randValue = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
if (thresholdToDecay < randValue) {
return max(currentEncodedProbability - ENCODED_PROBABILITY_STEP, 0);
} else {
return currentEncodedProbability;
}
} }
/* static */ int ForgettingCurveUtils::decodeUnigramProbability(const int encodedProbability) { /* static */ int ForgettingCurveUtils::decodeUnigramProbability(const int encodedProbability) {
const int probability = encodedProbability - MIN_VALID_UNIGRAM_PROBABILITY; const int probability = encodedProbability - MIN_VALID_ENCODED_PROBABILITY;
if (probability < 0) { if (probability < 0) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} else { } else {
return min(probability, MAX_UNIGRAM_PROBABILITY); return min(probability, MAX_ENCODED_PROBABILITY) * 8;
} }
} }
/* static */ int ForgettingCurveUtils::decodeBigramProbabilityDelta( /* static */ int ForgettingCurveUtils::decodeBigramProbability(const int encodedProbability) {
const int encodedProbabilityDelta) { const int probability = encodedProbability - MIN_VALID_ENCODED_PROBABILITY;
const int probabilityDelta = encodedProbabilityDelta - MIN_VALID_BIGRAM_PROBABILITY_DELTA; if (probability < 0) {
if (probabilityDelta < 0) {
return NOT_A_PROBABILITY; return NOT_A_PROBABILITY;
} else { } else {
return min(probabilityDelta, MAX_BIGRAM_PROBABILITY_DELTA); return min(probability, MAX_ENCODED_PROBABILITY) * 8;
} }
} }
/* static */ int ForgettingCurveUtils::getDecayedProbability(const int rawProbability) { // See comments in ProbabilityUtils::backoff().
return rawProbability; /* static */ int ForgettingCurveUtils::backoff(const int unigramProbability) {
if (unigramProbability == NOT_A_PROBABILITY) {
return NOT_A_PROBABILITY;
} else {
return max(unigramProbability - 8, 0);
}
} }
} // namespace latinime } // namespace latinime

View file

@ -24,7 +24,6 @@ namespace latinime {
// TODO: Check the elapsed time and decrease the probability depending on the time. Time field is // TODO: Check the elapsed time and decrease the probability depending on the time. Time field is
// required to introduced to each terminal PtNode and bigram entry. // required to introduced to each terminal PtNode and bigram entry.
// TODO: Quit using bigram probability to indicate the delta. // TODO: Quit using bigram probability to indicate the delta.
// TODO: Quit using bigram probability delta.
class ForgettingCurveUtils { class ForgettingCurveUtils {
public: public:
static const int MAX_UNIGRAM_COUNT; static const int MAX_UNIGRAM_COUNT;
@ -33,38 +32,30 @@ class ForgettingCurveUtils {
static const int MAX_BIGRAM_COUNT_AFTER_GC; static const int MAX_BIGRAM_COUNT_AFTER_GC;
static int getProbability(const int encodedUnigramProbability, static int getProbability(const int encodedUnigramProbability,
const int encodedBigramProbabilityDelta); const int encodedBigramProbability);
static int getUpdatedUnigramProbability(const int originalEncodedProbability, static int getUpdatedEncodedProbability(const int originalEncodedProbability,
const int newProbability); const int newProbability);
static int getUpdatedBigramProbabilityDelta(const int originalEncodedProbabilityDelta, static int isValidEncodedProbability(const int encodedProbability);
const int newProbability);
static int isValidUnigram(const int encodedUnigramProbability); static int getEncodedProbabilityToSave(const int encodedProbability);
static int isValidBigram(const int encodedProbabilityDelta);
static int getUnigramProbabilityToSave(const int encodedProbability);
static int getBigramProbabilityDeltaToSave(const int encodedProbabilityDelta);
private: private:
DISALLOW_IMPLICIT_CONSTRUCTORS(ForgettingCurveUtils); DISALLOW_IMPLICIT_CONSTRUCTORS(ForgettingCurveUtils);
static const int MAX_COMPUTED_PROBABILITY; static const int MAX_COMPUTED_PROBABILITY;
static const int MAX_UNIGRAM_PROBABILITY; static const int MAX_ENCODED_PROBABILITY;
static const int MIN_VALID_UNIGRAM_PROBABILITY; static const int MIN_VALID_ENCODED_PROBABILITY;
static const int UNIGRAM_PROBABILITY_STEP; static const int ENCODED_PROBABILITY_STEP;
static const int MAX_BIGRAM_PROBABILITY_DELTA;
static const int MIN_VALID_BIGRAM_PROBABILITY_DELTA; static const float MIN_PROBABILITY_TO_DECAY;
static const int BIGRAM_PROBABILITY_DELTA_STEP;
static int decodeUnigramProbability(const int encodedProbability); static int decodeUnigramProbability(const int encodedProbability);
static int decodeBigramProbabilityDelta(const int encodedProbability); static int decodeBigramProbability(const int encodedProbability);
static int getDecayedProbability(const int rawProbability); static int backoff(const int unigramProbability);
}; };
} // namespace latinime } // namespace latinime
#endif /* LATINIME_FORGETTING_CURVE_UTILS_H */ #endif /* LATINIME_FORGETTING_CURVE_UTILS_H */

View file

@ -50,14 +50,18 @@ public class BinaryDictionaryDecayingTests extends AndroidTestCase {
} }
private void forcePassingShortTime(final BinaryDictionary binaryDictionary) { private void forcePassingShortTime(final BinaryDictionary binaryDictionary) {
// Entries having low probability would be suppressed once in 2 GCs.
final int count = 2;
for (int i = 0; i < count; i++) {
binaryDictionary.getPropertyForTests(SET_NEEDS_TO_DECAY_FOR_TESTING_KEY); binaryDictionary.getPropertyForTests(SET_NEEDS_TO_DECAY_FOR_TESTING_KEY);
binaryDictionary.flushWithGC(); binaryDictionary.flushWithGC();
} }
}
private void forcePassingLongTime(final BinaryDictionary binaryDictionary) { private void forcePassingLongTime(final BinaryDictionary binaryDictionary) {
// Currently, probabilities are decayed when GC is run. All entries that have never been // Currently, probabilities are decayed when GC is run. All entries that have never been
// typed in 32 GCs are removed. // typed in 128 GCs would be removed.
final int count = 32; final int count = 128;
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
binaryDictionary.getPropertyForTests(SET_NEEDS_TO_DECAY_FOR_TESTING_KEY); binaryDictionary.getPropertyForTests(SET_NEEDS_TO_DECAY_FOR_TESTING_KEY);
binaryDictionary.flushWithGC(); binaryDictionary.flushWithGC();