Merge "Refactoring method to get code points and code point count."
This commit is contained in:
commit
7313b0debe
10 changed files with 47 additions and 61 deletions
|
@ -77,10 +77,8 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
|
|||
return;
|
||||
}
|
||||
int targetWordCodePoints[MAX_WORD_LENGTH];
|
||||
int unigramProbability = 0;
|
||||
const int codePointCount = mDictStructurePolicy->
|
||||
getCodePointsAndProbabilityAndReturnCodePointCount(targetWordId, MAX_WORD_LENGTH,
|
||||
targetWordCodePoints, &unigramProbability);
|
||||
const int codePointCount = mDictStructurePolicy->getCodePointsAndReturnCodePointCount(
|
||||
targetWordId, MAX_WORD_LENGTH, targetWordCodePoints);
|
||||
if (codePointCount <= 0) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -51,9 +51,8 @@ class DictionaryStructureWithBufferPolicy {
|
|||
virtual void createAndGetAllChildDicNodes(const DicNode *const dicNode,
|
||||
DicNodeVector *const childDicNodes) const = 0;
|
||||
|
||||
virtual int getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
const int wordId, const int maxCodePointCount, int *const outCodePoints,
|
||||
int *const outUnigramProbability) const = 0;
|
||||
virtual int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
|
||||
int *const outCodePoints) const = 0;
|
||||
|
||||
virtual int getWordId(const CodePointArrayView wordCodePoints,
|
||||
const bool forceLowerCaseSearch) const = 0;
|
||||
|
|
|
@ -87,14 +87,13 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
|
|||
}
|
||||
}
|
||||
|
||||
int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
const int wordId, const int maxCodePointCount, int *const outCodePoints,
|
||||
int *const outUnigramProbability) const {
|
||||
int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
|
||||
const int maxCodePointCount, int *const outCodePoints) const {
|
||||
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
|
||||
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
|
||||
readingHelper.initWithPtNodePos(ptNodePos);
|
||||
const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
maxCodePointCount, outCodePoints, outUnigramProbability);
|
||||
const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount(
|
||||
maxCodePointCount, outCodePoints);
|
||||
if (readingHelper.isError()) {
|
||||
mIsCorrupted = true;
|
||||
AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount().");
|
||||
|
@ -521,11 +520,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
|
|||
if (word1TerminalPtNodePos == NOT_A_DICT_POS) {
|
||||
continue;
|
||||
}
|
||||
// Word (unigram) probability
|
||||
int word1Probability = NOT_A_PROBABILITY;
|
||||
const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
const int codePointCount = getCodePointsAndReturnCodePointCount(
|
||||
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
|
||||
bigramWord1CodePoints, &word1Probability);
|
||||
bigramWord1CodePoints);
|
||||
const std::vector<int> word1(bigramWord1CodePoints,
|
||||
bigramWord1CodePoints + codePointCount);
|
||||
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
|
||||
|
@ -580,10 +577,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
|
|||
return 0;
|
||||
}
|
||||
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
|
||||
int unigramProbability = NOT_A_PROBABILITY;
|
||||
*outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints,
|
||||
&unigramProbability);
|
||||
*outCodePointCount = getCodePointsAndReturnCodePointCount(
|
||||
getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints);
|
||||
const int nextToken = token + 1;
|
||||
if (nextToken >= terminalPtNodePositionsVectorSize) {
|
||||
// All words have been iterated.
|
||||
|
|
|
@ -85,9 +85,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
|||
void createAndGetAllChildDicNodes(const DicNode *const dicNode,
|
||||
DicNodeVector *const childDicNodes) const;
|
||||
|
||||
int getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
const int wordId, const int maxCodePointCount, int *const outCodePoints,
|
||||
int *const outUnigramProbability) const;
|
||||
int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
|
||||
int *const outCodePoints) const;
|
||||
|
||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||
|
||||
|
|
|
@ -175,8 +175,8 @@ bool DynamicPtReadingHelper::traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFi
|
|||
return !isError();
|
||||
}
|
||||
|
||||
int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) {
|
||||
int DynamicPtReadingHelper::getCodePointsAndReturnCodePointCount(const int maxCodePointCount,
|
||||
int *const outCodePoints) {
|
||||
// This method traverses parent nodes from the terminal by following parent pointers; thus,
|
||||
// node code points are stored in the buffer in the reverse order.
|
||||
int reverseCodePoints[maxCodePointCount];
|
||||
|
@ -184,11 +184,8 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
|
|||
// First, read the terminal node and get its probability.
|
||||
if (!isValidTerminalNode(terminalPtNodeParams)) {
|
||||
// Node at the ptNodePos is not a valid terminal node.
|
||||
*outUnigramProbability = NOT_A_PROBABILITY;
|
||||
return 0;
|
||||
}
|
||||
// Store terminal node probability.
|
||||
*outUnigramProbability = terminalPtNodeParams.getProbability();
|
||||
// Then, following parent node link to the dictionary root and fetch node code points.
|
||||
int totalCodePointCount = 0;
|
||||
while (!isEnd()) {
|
||||
|
@ -196,7 +193,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
|
|||
totalCodePointCount = getTotalCodePointCount(ptNodeParams);
|
||||
if (!ptNodeParams.isValid() || totalCodePointCount > maxCodePointCount) {
|
||||
// The ptNodePos is not a valid terminal node position in the dictionary.
|
||||
*outUnigramProbability = NOT_A_PROBABILITY;
|
||||
return 0;
|
||||
}
|
||||
// Store node code points to buffer in the reverse order.
|
||||
|
@ -207,7 +203,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
|
|||
}
|
||||
if (isError()) {
|
||||
// The node position or the dictionary is invalid.
|
||||
*outUnigramProbability = NOT_A_PROBABILITY;
|
||||
return 0;
|
||||
}
|
||||
// Reverse the stored code points to output them.
|
||||
|
|
|
@ -211,8 +211,7 @@ class DynamicPtReadingHelper {
|
|||
bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner(
|
||||
TraversingEventListener *const listener);
|
||||
|
||||
int getCodePointsAndProbabilityAndReturnCodePointCount(const int maxCodePointCount,
|
||||
int *const outCodePoints, int *const outUnigramProbability);
|
||||
int getCodePointsAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints);
|
||||
|
||||
int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length,
|
||||
const bool forceLowerCaseSearch);
|
||||
|
|
|
@ -58,6 +58,11 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo
|
|||
}
|
||||
}
|
||||
|
||||
int PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
|
||||
const int maxCodePointCount, int *const outCodePoints) const {
|
||||
return getCodePointsAndProbabilityAndReturnCodePointCount(wordId, maxCodePointCount,
|
||||
outCodePoints, nullptr /* outUnigramProbability */);
|
||||
}
|
||||
// This retrieves code points and the probability of the word by its id.
|
||||
// Due to the fact that words are ordered in the dictionary in a strict breadth-first order,
|
||||
// it is possible to check for this with advantageous complexity. For each PtNode array, we search
|
||||
|
@ -82,6 +87,9 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
|
|||
int pos = getRootPosition();
|
||||
int wordPos = 0;
|
||||
const int *const codePointTable = mHeaderPolicy.getCodePointTable();
|
||||
if (outUnigramProbability) {
|
||||
*outUnigramProbability = NOT_A_PROBABILITY;
|
||||
}
|
||||
// One iteration of the outer loop iterates through PtNode arrays. As stated above, we will
|
||||
// only traverse PtNodes that are actually a part of the terminal we are searching, so each
|
||||
// time we enter this loop we are one depth level further than last time.
|
||||
|
@ -97,7 +105,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
|
|||
pos, mBuffer.size());
|
||||
mIsCorrupted = true;
|
||||
ASSERT(false);
|
||||
*outUnigramProbability = NOT_A_PROBABILITY;
|
||||
return 0;
|
||||
}
|
||||
for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
|
||||
|
@ -107,7 +114,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
|
|||
AKLOGE("PtNode position is invalid. pos: %d, dict size: %zd", pos, mBuffer.size());
|
||||
mIsCorrupted = true;
|
||||
ASSERT(false);
|
||||
*outUnigramProbability = NOT_A_PROBABILITY;
|
||||
return 0;
|
||||
}
|
||||
const PatriciaTrieReadingUtils::NodeFlags flags =
|
||||
|
@ -130,9 +136,11 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
|
|||
mBuffer.data(), codePointTable, &pos);
|
||||
}
|
||||
}
|
||||
*outUnigramProbability =
|
||||
PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(),
|
||||
&pos);
|
||||
if (outUnigramProbability) {
|
||||
*outUnigramProbability =
|
||||
PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(
|
||||
mBuffer.data(), &pos);
|
||||
}
|
||||
return ++wordPos;
|
||||
}
|
||||
// We need to skip past this PtNode, so skip any remaining code points after the
|
||||
|
@ -234,7 +242,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
|
|||
pos);
|
||||
mIsCorrupted = true;
|
||||
ASSERT(false);
|
||||
*outUnigramProbability = NOT_A_PROBABILITY;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
@ -257,7 +264,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
|
|||
AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos);
|
||||
mIsCorrupted = true;
|
||||
ASSERT(false);
|
||||
*outUnigramProbability = NOT_A_PROBABILITY;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
@ -497,10 +503,8 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC
|
|||
return 0;
|
||||
}
|
||||
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
|
||||
int unigramProbability = NOT_A_PROBABILITY;
|
||||
*outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints,
|
||||
&unigramProbability);
|
||||
*outCodePointCount = getCodePointsAndReturnCodePointCount(
|
||||
getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints);
|
||||
const int nextToken = token + 1;
|
||||
if (nextToken >= terminalPtNodePositionsVectorSize) {
|
||||
// All words have been iterated.
|
||||
|
|
|
@ -58,9 +58,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
|||
void createAndGetAllChildDicNodes(const DicNode *const dicNode,
|
||||
DicNodeVector *const childDicNodes) const;
|
||||
|
||||
int getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
const int wordId, const int maxCodePointCount, int *const outCodePoints,
|
||||
int *const outUnigramProbability) const;
|
||||
int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
|
||||
int *const outCodePoints) const;
|
||||
|
||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||
|
||||
|
@ -155,6 +154,9 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
|||
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
|
||||
mutable bool mIsCorrupted;
|
||||
|
||||
int getCodePointsAndProbabilityAndReturnCodePointCount(const int wordId,
|
||||
const int maxCodePointCount, int *const outCodePoints,
|
||||
int *const outUnigramProbability) const;
|
||||
int getShortcutPositionOfPtNode(const int ptNodePos) const;
|
||||
int getBigramsPositionOfPtNode(const int ptNodePos) const;
|
||||
int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos,
|
||||
|
|
|
@ -74,15 +74,14 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
|
|||
}
|
||||
}
|
||||
|
||||
int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
const int wordId, const int maxCodePointCount, int *const outCodePoints,
|
||||
int *const outUnigramProbability) const {
|
||||
int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
|
||||
const int maxCodePointCount, int *const outCodePoints) const {
|
||||
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
|
||||
const int ptNodePos =
|
||||
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
|
||||
readingHelper.initWithPtNodePos(ptNodePos);
|
||||
const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
maxCodePointCount, outCodePoints, outUnigramProbability);
|
||||
const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount(
|
||||
maxCodePointCount, outCodePoints);
|
||||
if (readingHelper.isError()) {
|
||||
mIsCorrupted = true;
|
||||
AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount().");
|
||||
|
@ -465,10 +464,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
|
|||
int bigramWord1CodePoints[MAX_WORD_LENGTH];
|
||||
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
|
||||
prevWordIds)) {
|
||||
// Word (unigram) probability
|
||||
int word1Probability = NOT_A_PROBABILITY;
|
||||
const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
entry.getWordId(), MAX_WORD_LENGTH, bigramWord1CodePoints, &word1Probability);
|
||||
const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
|
||||
MAX_WORD_LENGTH, bigramWord1CodePoints);
|
||||
const std::vector<int> word1(bigramWord1CodePoints,
|
||||
bigramWord1CodePoints + codePointCount);
|
||||
const ProbabilityEntry probabilityEntry = entry.getProbabilityEntry();
|
||||
|
@ -524,9 +521,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
|
|||
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
|
||||
const PtNodeParams ptNodeParams =
|
||||
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(terminalPtNodePos);
|
||||
int unigramProbability = NOT_A_PROBABILITY;
|
||||
*outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
ptNodeParams.getTerminalId(), MAX_WORD_LENGTH, outCodePoints, &unigramProbability);
|
||||
*outCodePointCount = getCodePointsAndReturnCodePointCount(ptNodeParams.getTerminalId(),
|
||||
MAX_WORD_LENGTH, outCodePoints);
|
||||
const int nextToken = token + 1;
|
||||
if (nextToken >= terminalPtNodePositionsVectorSize) {
|
||||
// All words have been iterated.
|
||||
|
|
|
@ -62,9 +62,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
|||
void createAndGetAllChildDicNodes(const DicNode *const dicNode,
|
||||
DicNodeVector *const childDicNodes) const;
|
||||
|
||||
int getCodePointsAndProbabilityAndReturnCodePointCount(
|
||||
const int wordId, const int maxCodePointCount, int *const outCodePoints,
|
||||
int *const outUnigramProbability) const;
|
||||
int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
|
||||
int *const outCodePoints) const;
|
||||
|
||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||
|
||||
|
|
Loading…
Reference in a new issue