Merge "Refactoring method to get code points and code point count."

This commit is contained in:
Keisuke Kuroyanagi 2014-09-24 07:24:47 +00:00 committed by Android (Google) Code Review
commit 7313b0debe
10 changed files with 47 additions and 61 deletions

View file

@ -77,10 +77,8 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
return; return;
} }
int targetWordCodePoints[MAX_WORD_LENGTH]; int targetWordCodePoints[MAX_WORD_LENGTH];
int unigramProbability = 0; const int codePointCount = mDictStructurePolicy->getCodePointsAndReturnCodePointCount(
const int codePointCount = mDictStructurePolicy-> targetWordId, MAX_WORD_LENGTH, targetWordCodePoints);
getCodePointsAndProbabilityAndReturnCodePointCount(targetWordId, MAX_WORD_LENGTH,
targetWordCodePoints, &unigramProbability);
if (codePointCount <= 0) { if (codePointCount <= 0) {
return; return;
} }

View file

@ -51,9 +51,8 @@ class DictionaryStructureWithBufferPolicy {
virtual void createAndGetAllChildDicNodes(const DicNode *const dicNode, virtual void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const = 0; DicNodeVector *const childDicNodes) const = 0;
virtual int getCodePointsAndProbabilityAndReturnCodePointCount( virtual int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outCodePoints) const = 0;
int *const outUnigramProbability) const = 0;
virtual int getWordId(const CodePointArrayView wordCodePoints, virtual int getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const = 0; const bool forceLowerCaseSearch) const = 0;

View file

@ -87,14 +87,13 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
} }
} }
int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
const int wordId, const int maxCodePointCount, int *const outCodePoints, const int maxCodePointCount, int *const outCodePoints) const {
int *const outUnigramProbability) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
readingHelper.initWithPtNodePos(ptNodePos); readingHelper.initWithPtNodePos(ptNodePos);
const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount( const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount(
maxCodePointCount, outCodePoints, outUnigramProbability); maxCodePointCount, outCodePoints);
if (readingHelper.isError()) { if (readingHelper.isError()) {
mIsCorrupted = true; mIsCorrupted = true;
AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount()."); AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount().");
@ -521,11 +520,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
if (word1TerminalPtNodePos == NOT_A_DICT_POS) { if (word1TerminalPtNodePos == NOT_A_DICT_POS) {
continue; continue;
} }
// Word (unigram) probability const int codePointCount = getCodePointsAndReturnCodePointCount(
int word1Probability = NOT_A_PROBABILITY;
const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH, getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
bigramWord1CodePoints, &word1Probability); bigramWord1CodePoints);
const std::vector<int> word1(bigramWord1CodePoints, const std::vector<int> word1(bigramWord1CodePoints,
bigramWord1CodePoints + codePointCount); bigramWord1CodePoints + codePointCount);
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
@ -580,10 +577,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
return 0; return 0;
} }
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
int unigramProbability = NOT_A_PROBABILITY; *outCodePointCount = getCodePointsAndReturnCodePointCount(
*outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints);
getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints,
&unigramProbability);
const int nextToken = token + 1; const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) { if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated. // All words have been iterated.

View file

@ -85,9 +85,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
void createAndGetAllChildDicNodes(const DicNode *const dicNode, void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const; DicNodeVector *const childDicNodes) const;
int getCodePointsAndProbabilityAndReturnCodePointCount( int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outCodePoints) const;
int *const outUnigramProbability) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;

View file

@ -175,8 +175,8 @@ bool DynamicPtReadingHelper::traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFi
return !isError(); return !isError();
} }
int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( int DynamicPtReadingHelper::getCodePointsAndReturnCodePointCount(const int maxCodePointCount,
const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) { int *const outCodePoints) {
// This method traverses parent nodes from the terminal by following parent pointers; thus, // This method traverses parent nodes from the terminal by following parent pointers; thus,
// node code points are stored in the buffer in the reverse order. // node code points are stored in the buffer in the reverse order.
int reverseCodePoints[maxCodePointCount]; int reverseCodePoints[maxCodePointCount];
@ -184,11 +184,8 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
// First, read the terminal node and get its probability. // First, read the terminal node and get its probability.
if (!isValidTerminalNode(terminalPtNodeParams)) { if (!isValidTerminalNode(terminalPtNodeParams)) {
// Node at the ptNodePos is not a valid terminal node. // Node at the ptNodePos is not a valid terminal node.
*outUnigramProbability = NOT_A_PROBABILITY;
return 0; return 0;
} }
// Store terminal node probability.
*outUnigramProbability = terminalPtNodeParams.getProbability();
// Then, following parent node link to the dictionary root and fetch node code points. // Then, following parent node link to the dictionary root and fetch node code points.
int totalCodePointCount = 0; int totalCodePointCount = 0;
while (!isEnd()) { while (!isEnd()) {
@ -196,7 +193,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
totalCodePointCount = getTotalCodePointCount(ptNodeParams); totalCodePointCount = getTotalCodePointCount(ptNodeParams);
if (!ptNodeParams.isValid() || totalCodePointCount > maxCodePointCount) { if (!ptNodeParams.isValid() || totalCodePointCount > maxCodePointCount) {
// The ptNodePos is not a valid terminal node position in the dictionary. // The ptNodePos is not a valid terminal node position in the dictionary.
*outUnigramProbability = NOT_A_PROBABILITY;
return 0; return 0;
} }
// Store node code points to buffer in the reverse order. // Store node code points to buffer in the reverse order.
@ -207,7 +203,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
} }
if (isError()) { if (isError()) {
// The node position or the dictionary is invalid. // The node position or the dictionary is invalid.
*outUnigramProbability = NOT_A_PROBABILITY;
return 0; return 0;
} }
// Reverse the stored code points to output them. // Reverse the stored code points to output them.

View file

@ -211,8 +211,7 @@ class DynamicPtReadingHelper {
bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner(
TraversingEventListener *const listener); TraversingEventListener *const listener);
int getCodePointsAndProbabilityAndReturnCodePointCount(const int maxCodePointCount, int getCodePointsAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints);
int *const outCodePoints, int *const outUnigramProbability);
int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length, int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length,
const bool forceLowerCaseSearch); const bool forceLowerCaseSearch);

View file

@ -58,6 +58,11 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo
} }
} }
int PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
        const int maxCodePointCount, int *const outCodePoints) const {
    // Only the code points are needed here, so delegate to the probability-aware lookup with a
    // null destination for the unigram probability.
    int *const noUnigramProbabilityOutput = nullptr;
    const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
            wordId, maxCodePointCount, outCodePoints, noUnigramProbabilityOutput);
    return codePointCount;
}
// This retrieves code points and the probability of the word by its id. // This retrieves code points and the probability of the word by its id.
// Due to the fact that words are ordered in the dictionary in a strict breadth-first order, // Due to the fact that words are ordered in the dictionary in a strict breadth-first order,
// it is possible to check for this with advantageous complexity. For each PtNode array, we search // it is possible to check for this with advantageous complexity. For each PtNode array, we search
@ -82,6 +87,9 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
int pos = getRootPosition(); int pos = getRootPosition();
int wordPos = 0; int wordPos = 0;
const int *const codePointTable = mHeaderPolicy.getCodePointTable(); const int *const codePointTable = mHeaderPolicy.getCodePointTable();
if (outUnigramProbability) {
*outUnigramProbability = NOT_A_PROBABILITY;
}
// One iteration of the outer loop iterates through PtNode arrays. As stated above, we will // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will
// only traverse PtNodes that are actually a part of the terminal we are searching, so each // only traverse PtNodes that are actually a part of the terminal we are searching, so each
// time we enter this loop we are one depth level further than last time. // time we enter this loop we are one depth level further than last time.
@ -97,7 +105,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
pos, mBuffer.size()); pos, mBuffer.size());
mIsCorrupted = true; mIsCorrupted = true;
ASSERT(false); ASSERT(false);
*outUnigramProbability = NOT_A_PROBABILITY;
return 0; return 0;
} }
for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
@ -107,7 +114,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
AKLOGE("PtNode position is invalid. pos: %d, dict size: %zd", pos, mBuffer.size()); AKLOGE("PtNode position is invalid. pos: %d, dict size: %zd", pos, mBuffer.size());
mIsCorrupted = true; mIsCorrupted = true;
ASSERT(false); ASSERT(false);
*outUnigramProbability = NOT_A_PROBABILITY;
return 0; return 0;
} }
const PatriciaTrieReadingUtils::NodeFlags flags = const PatriciaTrieReadingUtils::NodeFlags flags =
@ -130,9 +136,11 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
mBuffer.data(), codePointTable, &pos); mBuffer.data(), codePointTable, &pos);
} }
} }
*outUnigramProbability = if (outUnigramProbability) {
PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), *outUnigramProbability =
&pos); PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(
mBuffer.data(), &pos);
}
return ++wordPos; return ++wordPos;
} }
// We need to skip past this PtNode, so skip any remaining code points after the // We need to skip past this PtNode, so skip any remaining code points after the
@ -234,7 +242,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
pos); pos);
mIsCorrupted = true; mIsCorrupted = true;
ASSERT(false); ASSERT(false);
*outUnigramProbability = NOT_A_PROBABILITY;
return 0; return 0;
} }
} }
@ -257,7 +264,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos); AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos);
mIsCorrupted = true; mIsCorrupted = true;
ASSERT(false); ASSERT(false);
*outUnigramProbability = NOT_A_PROBABILITY;
return 0; return 0;
} }
} }
@ -497,10 +503,8 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC
return 0; return 0;
} }
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
int unigramProbability = NOT_A_PROBABILITY; *outCodePointCount = getCodePointsAndReturnCodePointCount(
*outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints);
getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints,
&unigramProbability);
const int nextToken = token + 1; const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) { if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated. // All words have been iterated.

View file

@ -58,9 +58,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
void createAndGetAllChildDicNodes(const DicNode *const dicNode, void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const; DicNodeVector *const childDicNodes) const;
int getCodePointsAndProbabilityAndReturnCodePointCount( int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outCodePoints) const;
int *const outUnigramProbability) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
@ -155,6 +154,9 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
std::vector<int> mTerminalPtNodePositionsForIteratingWords; std::vector<int> mTerminalPtNodePositionsForIteratingWords;
mutable bool mIsCorrupted; mutable bool mIsCorrupted;
// Retrieves the code points and the unigram probability of the word with the given id and
// returns the code point count (the error paths visible in the implementation return 0).
// outUnigramProbability may be nullptr when the caller does not need the probability; the
// public getCodePointsAndReturnCodePointCount() calls it that way.
int getCodePointsAndProbabilityAndReturnCodePointCount(const int wordId,
const int maxCodePointCount, int *const outCodePoints,
int *const outUnigramProbability) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const; int getShortcutPositionOfPtNode(const int ptNodePos) const;
int getBigramsPositionOfPtNode(const int ptNodePos) const; int getBigramsPositionOfPtNode(const int ptNodePos) const;
int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos, int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos,

View file

@ -74,15 +74,14 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
} }
} }
int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
const int wordId, const int maxCodePointCount, int *const outCodePoints, const int maxCodePointCount, int *const outCodePoints) const {
int *const outUnigramProbability) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
const int ptNodePos = const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
readingHelper.initWithPtNodePos(ptNodePos); readingHelper.initWithPtNodePos(ptNodePos);
const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount( const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount(
maxCodePointCount, outCodePoints, outUnigramProbability); maxCodePointCount, outCodePoints);
if (readingHelper.isError()) { if (readingHelper.isError()) {
mIsCorrupted = true; mIsCorrupted = true;
AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount()."); AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount().");
@ -465,10 +464,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
int bigramWord1CodePoints[MAX_WORD_LENGTH]; int bigramWord1CodePoints[MAX_WORD_LENGTH];
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries( for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
prevWordIds)) { prevWordIds)) {
// Word (unigram) probability const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
int word1Probability = NOT_A_PROBABILITY; MAX_WORD_LENGTH, bigramWord1CodePoints);
const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
entry.getWordId(), MAX_WORD_LENGTH, bigramWord1CodePoints, &word1Probability);
const std::vector<int> word1(bigramWord1CodePoints, const std::vector<int> word1(bigramWord1CodePoints,
bigramWord1CodePoints + codePointCount); bigramWord1CodePoints + codePointCount);
const ProbabilityEntry probabilityEntry = entry.getProbabilityEntry(); const ProbabilityEntry probabilityEntry = entry.getProbabilityEntry();
@ -524,9 +521,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
const PtNodeParams ptNodeParams = const PtNodeParams ptNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(terminalPtNodePos); mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(terminalPtNodePos);
int unigramProbability = NOT_A_PROBABILITY; *outCodePointCount = getCodePointsAndReturnCodePointCount(ptNodeParams.getTerminalId(),
*outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( MAX_WORD_LENGTH, outCodePoints);
ptNodeParams.getTerminalId(), MAX_WORD_LENGTH, outCodePoints, &unigramProbability);
const int nextToken = token + 1; const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) { if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated. // All words have been iterated.

View file

@ -62,9 +62,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
void createAndGetAllChildDicNodes(const DicNode *const dicNode, void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const; DicNodeVector *const childDicNodes) const;
int getCodePointsAndProbabilityAndReturnCodePointCount( int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outCodePoints) const;
int *const outUnigramProbability) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;