am 18f65017: am bfba64bc: Merge "Compute the correct frequency for bigram prediction" into jb-dev

* commit '18f650172d29800edb772d3798391b2d430426df':
  Compute the correct frequency for bigram prediction
main
Jean Chalard 2012-05-29 00:40:32 -07:00 committed by Android Git Automerger
commit 01fcf0dab0
2 changed files with 9 additions and 4 deletions

View File

@ -117,14 +117,17 @@ int BigramDictionary::getBigrams(const int32_t *prevWord, int prevWordLength, in
do { do {
bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
uint16_t bigramBuffer[MAX_WORD_LENGTH]; uint16_t bigramBuffer[MAX_WORD_LENGTH];
int unigramFreq;
const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags, const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags,
&pos); &pos);
const int length = BinaryFormat::getWordAtAddress(root, bigramPos, MAX_WORD_LENGTH, const int length = BinaryFormat::getWordAtAddress(root, bigramPos, MAX_WORD_LENGTH,
bigramBuffer); bigramBuffer, &unigramFreq);
// codesSize == 0 means we are trying to find bigram predictions. // codesSize == 0 means we are trying to find bigram predictions.
if (codesSize < 1 || checkFirstCharacter(bigramBuffer)) { if (codesSize < 1 || checkFirstCharacter(bigramBuffer)) {
const int frequency = UnigramDictionary::MASK_ATTRIBUTE_FREQUENCY & bigramFlags; const int bigramFreq = UnigramDictionary::MASK_ATTRIBUTE_FREQUENCY & bigramFlags;
const int frequency =
BinaryFormat::computeFrequencyForBigram(unigramFreq, bigramFreq);
if (addWordBigram(bigramBuffer, length, frequency)) { if (addWordBigram(bigramBuffer, length, frequency)) {
++bigramCount; ++bigramCount;
} }

View File

@ -66,7 +66,7 @@ class BinaryFormat {
static int getTerminalPosition(const uint8_t* const root, const int32_t* const inWord, static int getTerminalPosition(const uint8_t* const root, const int32_t* const inWord,
const int length); const int length);
static int getWordAtAddress(const uint8_t* const root, const int address, const int maxDepth, static int getWordAtAddress(const uint8_t* const root, const int address, const int maxDepth,
uint16_t* outWord); uint16_t* outWord, int* outUnigramFrequency);
static int computeFrequencyForBigram(const int unigramFreq, const int bigramFreq); static int computeFrequencyForBigram(const int unigramFreq, const int bigramFreq);
static int getProbability(const int position, const std::map<int, int> *bigramMap, static int getProbability(const int position, const std::map<int, int> *bigramMap,
const uint8_t *bigramFilter, const int unigramFreq); const uint8_t *bigramFilter, const int unigramFreq);
@ -391,10 +391,11 @@ inline int BinaryFormat::getTerminalPosition(const uint8_t* const root,
* address: the byte position of the last chargroup of the word we are searching for (this is * address: the byte position of the last chargroup of the word we are searching for (this is
* what is stored as the "bigram address" in each bigram) * what is stored as the "bigram address" in each bigram)
* outword: an array to write the found word, with MAX_WORD_LENGTH size. * outword: an array to write the found word, with MAX_WORD_LENGTH size.
* outUnigramFrequency: a pointer to an int to write the frequency into.
* Return value : the length of the word, of 0 if the word was not found. * Return value : the length of the word, of 0 if the word was not found.
*/ */
inline int BinaryFormat::getWordAtAddress(const uint8_t* const root, const int address, inline int BinaryFormat::getWordAtAddress(const uint8_t* const root, const int address,
const int maxDepth, uint16_t* outWord) { const int maxDepth, uint16_t* outWord, int* outUnigramFrequency) {
int pos = 0; int pos = 0;
int wordPos = 0; int wordPos = 0;
@ -427,6 +428,7 @@ inline int BinaryFormat::getWordAtAddress(const uint8_t* const root, const int a
nextChar = getCharCodeAndForwardPointer(root, &pos); nextChar = getCharCodeAndForwardPointer(root, &pos);
} }
} }
*outUnigramFrequency = readFrequencyWithoutMovingPointer(root, pos);
return ++wordPos; return ++wordPos;
} }
// We need to skip past this char group, so skip any remaining chars after the // We need to skip past this char group, so skip any remaining chars after the