am b6f286bf: Merge "Make bigram dictionary and traverse session use structure policy."

* commit 'b6f286bfa549ed91c67d591fc1725e35b114742b':
  Make bigram dictionary and traverse session use structure policy.
This commit is contained in:
Keisuke Kuroynagi 2013-07-15 19:47:55 -07:00 committed by Android Git Automerger
commit df266ac7aa
6 changed files with 40 additions and 46 deletions

View file

@ -123,9 +123,10 @@ int BigramDictionary::getPredictions(const int *prevWord, int prevWordLength, in
for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos); for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos);
bigramsIt.hasNext(); /* no-op */) { bigramsIt.hasNext(); /* no-op */) {
bigramsIt.next(); bigramsIt.next();
const int length = BinaryFormat::getWordAtAddress( const int length = mBinaryDictionaryInfo->getStructurePolicy()->
mBinaryDictionaryInfo->getDictRoot(), bigramsIt.getBigramPos(), getCodePointsAndProbabilityAndReturnCodePointCount(
MAX_WORD_LENGTH, bigramBuffer, &unigramProbability); mBinaryDictionaryInfo, bigramsIt.getBigramPos(), MAX_WORD_LENGTH,
bigramBuffer, &unigramProbability);
// inputSize == 0 means we are trying to find bigram predictions. // inputSize == 0 means we are trying to find bigram predictions.
if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) { if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) {
@ -153,18 +154,8 @@ int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const in
int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch); mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch);
if (NOT_VALID_WORD == pos) return 0; if (NOT_VALID_WORD == pos) return 0;
const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot(); return BinaryFormat::getBigramListPositionForWordPosition(
const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); mBinaryDictionaryInfo->getDictRoot(), pos);
if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) return 0;
if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) {
BinaryFormat::getCodePointAndForwardPointer(root, &pos);
} else {
pos = BinaryFormat::skipOtherCharacters(root, pos);
}
pos = BinaryFormat::skipProbability(flags, pos);
pos = BinaryFormat::skipChildrenPosition(flags, pos);
pos = BinaryFormat::skipShortcuts(root, flags, pos);
return pos;
} }
bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const { bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const {

View file

@ -71,8 +71,9 @@ class BinaryFormat {
static bool hasChildrenInFlags(const uint8_t flags); static bool hasChildrenInFlags(const uint8_t flags);
static int getTerminalPosition(const uint8_t *const root, const int *const inWord, static int getTerminalPosition(const uint8_t *const root, const int *const inWord,
const int length, const bool forceLowerCaseSearch); const int length, const bool forceLowerCaseSearch);
static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth, static int getCodePointsAndProbabilityAndReturnCodePointCount(
int *outWord, int *outUnigramProbability); const uint8_t *const root, const int nodePos, const int maxCodePointCount,
int *outCodePoints, int *outUnigramProbability);
static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);
private: private:
@ -342,8 +343,9 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root,
* outUnigramProbability: a pointer to an int to write the probability into. * outUnigramProbability: a pointer to an int to write the probability into.
* Return value : the length of the word, of 0 if the word was not found. * Return value : the length of the word, of 0 if the word was not found.
*/ */
AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address, AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount(
const int maxDepth, int *outWord, int *outUnigramProbability) { const uint8_t *const root, const int nodePos,
const int maxCodePointCount, int *outCodePoints, int *outUnigramProbability) {
int pos = 0; int pos = 0;
int wordPos = 0; int wordPos = 0;
@ -353,7 +355,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
// The only reason we count nodes is because we want to reduce the probability of infinite // The only reason we count nodes is because we want to reduce the probability of infinite
// looping in case there is a bug. Since we know there is an upper bound to the depth we are // looping in case there is a bug. Since we know there is an upper bound to the depth we are
// supposed to traverse, it does not hurt to count iterations. // supposed to traverse, it does not hurt to count iterations.
for (int loopCount = maxDepth; loopCount > 0; --loopCount) { for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) {
int lastCandidateGroupPos = 0; int lastCandidateGroupPos = 0;
// Let's loop through char groups in this node searching for either the terminal // Let's loop through char groups in this node searching for either the terminal
// or one of its ascendants. // or one of its ascendants.
@ -362,17 +364,17 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
const int startPos = pos; const int startPos = pos;
const uint8_t flags = getFlagsAndForwardPointer(root, &pos); const uint8_t flags = getFlagsAndForwardPointer(root, &pos);
const int character = getCodePointAndForwardPointer(root, &pos); const int character = getCodePointAndForwardPointer(root, &pos);
if (address == startPos) { if (nodePos == startPos) {
// We found the address. Copy the rest of the word in the buffer and return // We found the address. Copy the rest of the word in the buffer and return
// the length. // the length.
outWord[wordPos] = character; outCodePoints[wordPos] = character;
if (FLAG_HAS_MULTIPLE_CHARS & flags) { if (FLAG_HAS_MULTIPLE_CHARS & flags) {
int nextChar = getCodePointAndForwardPointer(root, &pos); int nextChar = getCodePointAndForwardPointer(root, &pos);
// We count chars in order to avoid infinite loops if the file is broken or // We count chars in order to avoid infinite loops if the file is broken or
// if there is some other bug // if there is some other bug
int charCount = maxDepth; int charCount = maxCodePointCount;
while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
outWord[++wordPos] = nextChar; outCodePoints[++wordPos] = nextChar;
nextChar = getCodePointAndForwardPointer(root, &pos); nextChar = getCodePointAndForwardPointer(root, &pos);
} }
} }
@ -399,7 +401,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
if (hasChildren) { if (hasChildren) {
// Here comes the tricky part. First, read the children position. // Here comes the tricky part. First, read the children position.
const int childrenPos = readChildrenPosition(root, flags, pos); const int childrenPos = readChildrenPosition(root, flags, pos);
if (childrenPos > address) { if (childrenPos > nodePos) {
// If the children pos is greater than address, it means the previous chargroup, // If the children pos is greater than address, it means the previous chargroup,
// which address is stored in lastCandidateGroupPos, was the right one. // which address is stored in lastCandidateGroupPos, was the right one.
found = true; found = true;
@ -429,12 +431,12 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
const int lastChar = const int lastChar =
getCodePointAndForwardPointer(root, &lastCandidateGroupPos); getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
// We copy all the characters in this group to the buffer // We copy all the characters in this group to the buffer
outWord[wordPos] = lastChar; outCodePoints[wordPos] = lastChar;
if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) { if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) {
int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
int charCount = maxDepth; int charCount = maxCodePointCount;
while (-1 != nextChar && --charCount > 0) { while (-1 != nextChar && --charCount > 0) {
outWord[++wordPos] = nextChar; outCodePoints[++wordPos] = nextChar;
nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
} }
} }

View file

@ -50,8 +50,9 @@ class DictionaryStructurePolicy {
const BinaryDictionaryInfo *const binaryDictionaryInfo, const BinaryDictionaryInfo *const binaryDictionaryInfo,
const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const = 0; const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const = 0;
virtual void getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo, virtual int getCodePointsAndProbabilityAndReturnCodePointCount(
const int terminalNodePos, const int maxDepth, int *const outWord, const BinaryDictionaryInfo *const binaryDictionaryInfo,
const int nodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const = 0; int *const outUnigramProbability) const = 0;
virtual int getTerminalNodePositionOfWord( virtual int getTerminalNodePositionOfWord(

View file

@ -18,10 +18,8 @@
#include "defines.h" #include "defines.h"
#include "jni.h" #include "jni.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/dictionary/binary_dictionary_header.h" #include "suggest/core/dictionary/binary_dictionary_header.h"
#include "suggest/core/dictionary/binary_dictionary_info.h" #include "suggest/core/dictionary/binary_dictionary_info.h"
#include "suggest/core/dictionary/binary_format.h"
#include "suggest/core/dictionary/dictionary.h" #include "suggest/core/dictionary/dictionary.h"
namespace latinime { namespace latinime {
@ -29,23 +27,22 @@ namespace latinime {
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
int prevWordLength, const SuggestOptions *const suggestOptions) { int prevWordLength, const SuggestOptions *const suggestOptions) {
mDictionary = dictionary; mDictionary = dictionary;
mMultiWordCostMultiplier = mDictionary->getBinaryDictionaryInfo() const BinaryDictionaryInfo *const binaryDictionaryInfo =
->getHeader()->getMultiWordCostMultiplier(); mDictionary->getBinaryDictionaryInfo();
mMultiWordCostMultiplier = binaryDictionaryInfo->getHeader()->getMultiWordCostMultiplier();
mSuggestOptions = suggestOptions; mSuggestOptions = suggestOptions;
if (!prevWord) { if (!prevWord) {
mPrevWordPos = NOT_VALID_WORD; mPrevWordPos = NOT_VALID_WORD;
return; return;
} }
// TODO: merge following similar calls to getTerminalPosition into one case-insensitive call. // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call.
mPrevWordPos = BinaryFormat::getTerminalPosition( mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord, binaryDictionaryInfo, prevWord, prevWordLength, false /* forceLowerCaseSearch */);
prevWordLength, false /* forceLowerCaseSearch */);
if (mPrevWordPos == NOT_VALID_WORD) { if (mPrevWordPos == NOT_VALID_WORD) {
// Check bigrams for lower-cased previous word if original was not found. Useful for // Check bigrams for lower-cased previous word if original was not found. Useful for
// auto-capitalized words like "The [current_word]". // auto-capitalized words like "The [current_word]".
mPrevWordPos = BinaryFormat::getTerminalPosition( mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord, binaryDictionaryInfo, prevWord, prevWordLength, true /* forceLowerCaseSearch */);
prevWordLength, true /* forceLowerCaseSearch */);
} }
} }

View file

@ -33,11 +33,13 @@ void PatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode,
// TODO: Move children creating methods form DicNodeUtils. // TODO: Move children creating methods form DicNodeUtils.
} }
void PatriciaTriePolicy::getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo, int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
const int terminalNodePos, const int maxDepth, int *const outWord, const BinaryDictionaryInfo *const binaryDictionaryInfo,
const int nodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const { int *const outUnigramProbability) const {
BinaryFormat::getWordAtAddress(binaryDictionaryInfo->getDictRoot(), terminalNodePos, return BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount(
maxDepth, outWord, outUnigramProbability); binaryDictionaryInfo->getDictRoot(), nodePos,
maxCodePointCount, outCodePoints, outUnigramProbability);
} }
int PatriciaTriePolicy::getTerminalNodePositionOfWord( int PatriciaTriePolicy::getTerminalNodePositionOfWord(

View file

@ -36,8 +36,9 @@ class PatriciaTriePolicy : public DictionaryStructurePolicy {
const BinaryDictionaryInfo *const binaryDictionaryInfo, const BinaryDictionaryInfo *const binaryDictionaryInfo,
const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const; const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const;
void getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo, int getCodePointsAndProbabilityAndReturnCodePointCount(
const int terminalNodePos, const int maxDepth, int *const outWord, const BinaryDictionaryInfo *const binaryDictionaryInfo,
const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const; int *const outUnigramProbability) const;
int getTerminalNodePositionOfWord( int getTerminalNodePositionOfWord(