am b6f286bf
: Merge "Make bigram dictionary and traverse session use structure policy."
* commit 'b6f286bfa549ed91c67d591fc1725e35b114742b': Make bigram dictionary and traverse session use structure policy.
This commit is contained in:
commit
df266ac7aa
6 changed files with 40 additions and 46 deletions
|
@ -123,9 +123,10 @@ int BigramDictionary::getPredictions(const int *prevWord, int prevWordLength, in
|
||||||
for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos);
|
for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos);
|
||||||
bigramsIt.hasNext(); /* no-op */) {
|
bigramsIt.hasNext(); /* no-op */) {
|
||||||
bigramsIt.next();
|
bigramsIt.next();
|
||||||
const int length = BinaryFormat::getWordAtAddress(
|
const int length = mBinaryDictionaryInfo->getStructurePolicy()->
|
||||||
mBinaryDictionaryInfo->getDictRoot(), bigramsIt.getBigramPos(),
|
getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||||
MAX_WORD_LENGTH, bigramBuffer, &unigramProbability);
|
mBinaryDictionaryInfo, bigramsIt.getBigramPos(), MAX_WORD_LENGTH,
|
||||||
|
bigramBuffer, &unigramProbability);
|
||||||
|
|
||||||
// inputSize == 0 means we are trying to find bigram predictions.
|
// inputSize == 0 means we are trying to find bigram predictions.
|
||||||
if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) {
|
if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) {
|
||||||
|
@ -153,18 +154,8 @@ int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const in
|
||||||
int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
|
int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
|
||||||
mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch);
|
mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch);
|
||||||
if (NOT_VALID_WORD == pos) return 0;
|
if (NOT_VALID_WORD == pos) return 0;
|
||||||
const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot();
|
return BinaryFormat::getBigramListPositionForWordPosition(
|
||||||
const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
|
mBinaryDictionaryInfo->getDictRoot(), pos);
|
||||||
if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) return 0;
|
|
||||||
if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) {
|
|
||||||
BinaryFormat::getCodePointAndForwardPointer(root, &pos);
|
|
||||||
} else {
|
|
||||||
pos = BinaryFormat::skipOtherCharacters(root, pos);
|
|
||||||
}
|
|
||||||
pos = BinaryFormat::skipProbability(flags, pos);
|
|
||||||
pos = BinaryFormat::skipChildrenPosition(flags, pos);
|
|
||||||
pos = BinaryFormat::skipShortcuts(root, flags, pos);
|
|
||||||
return pos;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const {
|
bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const {
|
||||||
|
|
|
@ -71,8 +71,9 @@ class BinaryFormat {
|
||||||
static bool hasChildrenInFlags(const uint8_t flags);
|
static bool hasChildrenInFlags(const uint8_t flags);
|
||||||
static int getTerminalPosition(const uint8_t *const root, const int *const inWord,
|
static int getTerminalPosition(const uint8_t *const root, const int *const inWord,
|
||||||
const int length, const bool forceLowerCaseSearch);
|
const int length, const bool forceLowerCaseSearch);
|
||||||
static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth,
|
static int getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||||
int *outWord, int *outUnigramProbability);
|
const uint8_t *const root, const int nodePos, const int maxCodePointCount,
|
||||||
|
int *outCodePoints, int *outUnigramProbability);
|
||||||
static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);
|
static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -342,8 +343,9 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root,
|
||||||
* outUnigramProbability: a pointer to an int to write the probability into.
|
* outUnigramProbability: a pointer to an int to write the probability into.
|
||||||
* Return value : the length of the word, or 0 if the word was not found.
|
* Return value : the length of the word, or 0 if the word was not found.
|
||||||
*/
|
*/
|
||||||
AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address,
|
AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||||
const int maxDepth, int *outWord, int *outUnigramProbability) {
|
const uint8_t *const root, const int nodePos,
|
||||||
|
const int maxCodePointCount, int *outCodePoints, int *outUnigramProbability) {
|
||||||
int pos = 0;
|
int pos = 0;
|
||||||
int wordPos = 0;
|
int wordPos = 0;
|
||||||
|
|
||||||
|
@ -353,7 +355,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
|
||||||
// The only reason we count nodes is because we want to reduce the probability of infinite
|
// The only reason we count nodes is because we want to reduce the probability of infinite
|
||||||
// looping in case there is a bug. Since we know there is an upper bound to the depth we are
|
// looping in case there is a bug. Since we know there is an upper bound to the depth we are
|
||||||
// supposed to traverse, it does not hurt to count iterations.
|
// supposed to traverse, it does not hurt to count iterations.
|
||||||
for (int loopCount = maxDepth; loopCount > 0; --loopCount) {
|
for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) {
|
||||||
int lastCandidateGroupPos = 0;
|
int lastCandidateGroupPos = 0;
|
||||||
// Let's loop through char groups in this node searching for either the terminal
|
// Let's loop through char groups in this node searching for either the terminal
|
||||||
// or one of its ascendants.
|
// or one of its ascendants.
|
||||||
|
@ -362,17 +364,17 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
|
||||||
const int startPos = pos;
|
const int startPos = pos;
|
||||||
const uint8_t flags = getFlagsAndForwardPointer(root, &pos);
|
const uint8_t flags = getFlagsAndForwardPointer(root, &pos);
|
||||||
const int character = getCodePointAndForwardPointer(root, &pos);
|
const int character = getCodePointAndForwardPointer(root, &pos);
|
||||||
if (address == startPos) {
|
if (nodePos == startPos) {
|
||||||
// We found the address. Copy the rest of the word in the buffer and return
|
// We found the address. Copy the rest of the word in the buffer and return
|
||||||
// the length.
|
// the length.
|
||||||
outWord[wordPos] = character;
|
outCodePoints[wordPos] = character;
|
||||||
if (FLAG_HAS_MULTIPLE_CHARS & flags) {
|
if (FLAG_HAS_MULTIPLE_CHARS & flags) {
|
||||||
int nextChar = getCodePointAndForwardPointer(root, &pos);
|
int nextChar = getCodePointAndForwardPointer(root, &pos);
|
||||||
// We count chars in order to avoid infinite loops if the file is broken or
|
// We count chars in order to avoid infinite loops if the file is broken or
|
||||||
// if there is some other bug
|
// if there is some other bug
|
||||||
int charCount = maxDepth;
|
int charCount = maxCodePointCount;
|
||||||
while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
|
while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
|
||||||
outWord[++wordPos] = nextChar;
|
outCodePoints[++wordPos] = nextChar;
|
||||||
nextChar = getCodePointAndForwardPointer(root, &pos);
|
nextChar = getCodePointAndForwardPointer(root, &pos);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -399,7 +401,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
|
||||||
if (hasChildren) {
|
if (hasChildren) {
|
||||||
// Here comes the tricky part. First, read the children position.
|
// Here comes the tricky part. First, read the children position.
|
||||||
const int childrenPos = readChildrenPosition(root, flags, pos);
|
const int childrenPos = readChildrenPosition(root, flags, pos);
|
||||||
if (childrenPos > address) {
|
if (childrenPos > nodePos) {
|
||||||
// If the children pos is greater than address, it means the previous chargroup,
|
// If the children pos is greater than address, it means the previous chargroup,
|
||||||
// whose address is stored in lastCandidateGroupPos, was the right one.
|
// whose address is stored in lastCandidateGroupPos, was the right one.
|
||||||
found = true;
|
found = true;
|
||||||
|
@ -429,12 +431,12 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co
|
||||||
const int lastChar =
|
const int lastChar =
|
||||||
getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
|
getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
|
||||||
// We copy all the characters in this group to the buffer
|
// We copy all the characters in this group to the buffer
|
||||||
outWord[wordPos] = lastChar;
|
outCodePoints[wordPos] = lastChar;
|
||||||
if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) {
|
if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) {
|
||||||
int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
|
int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
|
||||||
int charCount = maxDepth;
|
int charCount = maxCodePointCount;
|
||||||
while (-1 != nextChar && --charCount > 0) {
|
while (-1 != nextChar && --charCount > 0) {
|
||||||
outWord[++wordPos] = nextChar;
|
outCodePoints[++wordPos] = nextChar;
|
||||||
nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
|
nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,8 +50,9 @@ class DictionaryStructurePolicy {
|
||||||
const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
||||||
const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const = 0;
|
const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const = 0;
|
||||||
|
|
||||||
virtual void getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
virtual int getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||||
const int terminalNodePos, const int maxDepth, int *const outWord,
|
const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
||||||
|
const int nodePos, const int maxCodePointCount, int *const outCodePoints,
|
||||||
int *const outUnigramProbability) const = 0;
|
int *const outUnigramProbability) const = 0;
|
||||||
|
|
||||||
virtual int getTerminalNodePositionOfWord(
|
virtual int getTerminalNodePositionOfWord(
|
||||||
|
|
|
@ -18,10 +18,8 @@
|
||||||
|
|
||||||
#include "defines.h"
|
#include "defines.h"
|
||||||
#include "jni.h"
|
#include "jni.h"
|
||||||
#include "suggest/core/dicnode/dic_node_utils.h"
|
|
||||||
#include "suggest/core/dictionary/binary_dictionary_header.h"
|
#include "suggest/core/dictionary/binary_dictionary_header.h"
|
||||||
#include "suggest/core/dictionary/binary_dictionary_info.h"
|
#include "suggest/core/dictionary/binary_dictionary_info.h"
|
||||||
#include "suggest/core/dictionary/binary_format.h"
|
|
||||||
#include "suggest/core/dictionary/dictionary.h"
|
#include "suggest/core/dictionary/dictionary.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
@ -29,23 +27,22 @@ namespace latinime {
|
||||||
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
|
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
|
||||||
int prevWordLength, const SuggestOptions *const suggestOptions) {
|
int prevWordLength, const SuggestOptions *const suggestOptions) {
|
||||||
mDictionary = dictionary;
|
mDictionary = dictionary;
|
||||||
mMultiWordCostMultiplier = mDictionary->getBinaryDictionaryInfo()
|
const BinaryDictionaryInfo *const binaryDictionaryInfo =
|
||||||
->getHeader()->getMultiWordCostMultiplier();
|
mDictionary->getBinaryDictionaryInfo();
|
||||||
|
mMultiWordCostMultiplier = binaryDictionaryInfo->getHeader()->getMultiWordCostMultiplier();
|
||||||
mSuggestOptions = suggestOptions;
|
mSuggestOptions = suggestOptions;
|
||||||
if (!prevWord) {
|
if (!prevWord) {
|
||||||
mPrevWordPos = NOT_VALID_WORD;
|
mPrevWordPos = NOT_VALID_WORD;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// TODO: merge following similar calls to getTerminalPosition into one case-insensitive call.
|
// TODO: merge following similar calls to getTerminalPosition into one case-insensitive call.
|
||||||
mPrevWordPos = BinaryFormat::getTerminalPosition(
|
mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
|
||||||
dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord,
|
binaryDictionaryInfo, prevWord, prevWordLength, false /* forceLowerCaseSearch */);
|
||||||
prevWordLength, false /* forceLowerCaseSearch */);
|
|
||||||
if (mPrevWordPos == NOT_VALID_WORD) {
|
if (mPrevWordPos == NOT_VALID_WORD) {
|
||||||
// Check bigrams for lower-cased previous word if original was not found. Useful for
|
// Check bigrams for lower-cased previous word if original was not found. Useful for
|
||||||
// auto-capitalized words like "The [current_word]".
|
// auto-capitalized words like "The [current_word]".
|
||||||
mPrevWordPos = BinaryFormat::getTerminalPosition(
|
mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord(
|
||||||
dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord,
|
binaryDictionaryInfo, prevWord, prevWordLength, true /* forceLowerCaseSearch */);
|
||||||
prevWordLength, true /* forceLowerCaseSearch */);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,11 +33,13 @@ void PatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode,
|
||||||
// TODO: Move children creating methods from DicNodeUtils.
|
// TODO: Move children creating methods from DicNodeUtils.
|
||||||
}
|
}
|
||||||
|
|
||||||
void PatriciaTriePolicy::getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||||
const int terminalNodePos, const int maxDepth, int *const outWord,
|
const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
||||||
|
const int nodePos, const int maxCodePointCount, int *const outCodePoints,
|
||||||
int *const outUnigramProbability) const {
|
int *const outUnigramProbability) const {
|
||||||
BinaryFormat::getWordAtAddress(binaryDictionaryInfo->getDictRoot(), terminalNodePos,
|
return BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||||
maxDepth, outWord, outUnigramProbability);
|
binaryDictionaryInfo->getDictRoot(), nodePos,
|
||||||
|
maxCodePointCount, outCodePoints, outUnigramProbability);
|
||||||
}
|
}
|
||||||
|
|
||||||
int PatriciaTriePolicy::getTerminalNodePositionOfWord(
|
int PatriciaTriePolicy::getTerminalNodePositionOfWord(
|
||||||
|
|
|
@ -36,8 +36,9 @@ class PatriciaTriePolicy : public DictionaryStructurePolicy {
|
||||||
const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
||||||
const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const;
|
const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const;
|
||||||
|
|
||||||
void getWordAtPosition(const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
int getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||||
const int terminalNodePos, const int maxDepth, int *const outWord,
|
const BinaryDictionaryInfo *const binaryDictionaryInfo,
|
||||||
|
const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints,
|
||||||
int *const outUnigramProbability) const;
|
int *const outUnigramProbability) const;
|
||||||
|
|
||||||
int getTerminalNodePositionOfWord(
|
int getTerminalNodePositionOfWord(
|
||||||
|
|
Loading…
Reference in a new issue