diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h index b668aab78..71558edaa 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h @@ -42,7 +42,7 @@ class DynamicPatriciaTrieNodeReader { // Reads node information from dictionary buffer and updates members with the information. AK_FORCE_INLINE void fetchNodeInfoFromBuffer(const int nodePos) { - fetchNodeInfoFromBufferAndGetNodeCodePoints(mNodePos , 0 /* maxCodePointCount */, + fetchNodeInfoFromBufferAndGetNodeCodePoints(nodePos , 0 /* maxCodePointCount */, 0 /* outCodePoints */); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp index 9a180e6f7..0b73efae3 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp @@ -27,6 +27,8 @@ namespace latinime { const DynamicPatriciaTriePolicy DynamicPatriciaTriePolicy::sInstance; +// To avoid infinite loop caused by invalid or malicious forward links. 
+const int DynamicPatriciaTriePolicy::MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP = 100000; void DynamicPatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, const BinaryDictionaryInfo *const binaryDictionaryInfo, @@ -37,14 +39,23 @@ void DynamicPatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const d DynamicPatriciaTrieNodeReader nodeReader(binaryDictionaryInfo); int mergedNodeCodePoints[MAX_WORD_LENGTH]; int nextPos = dicNode->getChildrenPos(); + int totalChildCount = 0; do { const int childCount = PatriciaTrieReadingUtils::getGroupCountAndAdvancePosition( binaryDictionaryInfo->getDictRoot(), &nextPos); + totalChildCount += childCount; + if (childCount <= 0 || totalChildCount > MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP) { + // Invalid dictionary. + AKLOGI("Invalid dictionary. childCount: %d, totalChildCount: %d, MAX: %d", + childCount, totalChildCount, MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP); + ASSERT(false); + return; + } for (int i = 0; i < childCount; i++) { nodeReader.fetchNodeInfoFromBufferAndGetNodeCodePoints(nextPos, MAX_WORD_LENGTH, mergedNodeCodePoints); if (!nodeReader.isDeleted() && !nodeFilter->isFilteredOut(mergedNodeCodePoints[0])) { - // Push child note when the node is not deleted and not filtered out. + // Push child node when the node is not deleted and not filtered out. 
childDicNodes->pushLeavingChild(dicNode, nodeReader.getNodePos(), nodeReader.getChildrenPos(), nodeReader.getProbability(), nodeReader.isTerminal(), nodeReader.hasChildren(), @@ -55,13 +66,17 @@ void DynamicPatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const d } nextPos = DynamicPatriciaTrieReadingUtils::getForwardLinkPosition( binaryDictionaryInfo->getDictRoot(), nextPos); - } while(DynamicPatriciaTrieReadingUtils::isValidForwardLinkPosition(nextPos)); + } while (DynamicPatriciaTrieReadingUtils::isValidForwardLinkPosition(nextPos)); } int DynamicPatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const { + if (nodePos == NOT_A_VALID_WORD_POS) { + *outUnigramProbability = NOT_A_PROBABILITY; + return 0; + } // This method traverses parent nodes from the terminal by following parent pointers; thus, // node code points are stored in the buffer in the reverse order. int reverseCodePoints[maxCodePointCount]; @@ -106,12 +121,98 @@ int DynamicPatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCoun int DynamicPatriciaTriePolicy::getTerminalNodePositionOfWord( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int *const inWord, const int length, const bool forceLowerCaseSearch) const { - // TODO: Implement. - return NOT_A_DICT_POS; + // Walks the trie from the root, matching inWord (lower-cased when forceLowerCaseSearch is + // true) node by node. Returns the position of the node at which exactly length code points + // have been matched, or NOT_A_VALID_WORD_POS when no such node is found. + if (length <= 0) { + // Nothing to match, and a variable-length array must not have a non-positive size. + return NOT_A_VALID_WORD_POS; + } + int searchCodePoints[length]; + for (int i = 0; i < length; ++i) { + searchCodePoints[i] = forceLowerCaseSearch ? CharUtils::toLowerCase(inWord[i]) : inWord[i]; + } + int mergedNodeCodePoints[MAX_WORD_LENGTH]; + int currentLength = 0; + int pos = getRootPosition(); + DynamicPatriciaTrieNodeReader nodeReader(binaryDictionaryInfo); + while (currentLength <= length) { + // When foundMatchedNode becomes true, currentLength is increased at least once. 
+ bool foundMatchedNode = false; + int totalChildCount = 0; + do { + const int childCount = PatriciaTrieReadingUtils::getGroupCountAndAdvancePosition( + binaryDictionaryInfo->getDictRoot(), &pos); + totalChildCount += childCount; + if (childCount <= 0 || totalChildCount > MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP) { + // Invalid dictionary. + AKLOGI("Invalid dictionary. childCount: %d, totalChildCount: %d, MAX: %d", + childCount, totalChildCount, MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP); + ASSERT(false); + return NOT_A_VALID_WORD_POS; + } + for (int i = 0; i < childCount; i++) { + nodeReader.fetchNodeInfoFromBufferAndGetNodeCodePoints(pos, MAX_WORD_LENGTH, + mergedNodeCodePoints); + if (nodeReader.isDeleted() || nodeReader.getCodePointCount() <= 0) { + // Skip deleted or empty node. + pos = nodeReader.getSiblingNodePos(); + continue; + } + if (nodeReader.getCodePointCount() > length - currentLength) { + // This node has more code points than remain in the word, so it cannot match; + // comparing it would read past the end of searchCodePoints. + pos = nodeReader.getSiblingNodePos(); + continue; + } + bool matched = true; + for (int j = 0; j < nodeReader.getCodePointCount(); ++j) { + if (mergedNodeCodePoints[j] != searchCodePoints[currentLength + j]) { + // Different code point is found. + matched = false; + break; + } + } + if (matched) { + currentLength += nodeReader.getCodePointCount(); + if (length == currentLength) { + // Terminal position is found. + return nodeReader.getNodePos(); + } + if (!nodeReader.hasChildren()) { + return NOT_A_VALID_WORD_POS; + } + foundMatchedNode = true; + // Advance to the children nodes. + pos = nodeReader.getChildrenPos(); + break; + } + // Try next sibling node. + pos = nodeReader.getSiblingNodePos(); + } + if (foundMatchedNode) { + break; + } + // If the matched node is not found in the current node group, try to follow the + // forward link. + pos = DynamicPatriciaTrieReadingUtils::getForwardLinkPosition( + binaryDictionaryInfo->getDictRoot(), pos); + } while (DynamicPatriciaTrieReadingUtils::isValidForwardLinkPosition(pos)); + if (!foundMatchedNode) { + // Matched node is not found. 
+ return NOT_A_VALID_WORD_POS; + } + } + // If we already traversed the tree further than the word is long, this means + // there was no match (or we would have found it). + return NOT_A_VALID_WORD_POS; } int DynamicPatriciaTriePolicy::getUnigramProbability( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) const { + if (nodePos == NOT_A_VALID_WORD_POS) { + return NOT_A_PROBABILITY; + } DynamicPatriciaTrieNodeReader nodeReader(binaryDictionaryInfo); nodeReader.fetchNodeInfoFromBuffer(nodePos); if (nodeReader.isDeleted() || nodeReader.isBlacklisted() || nodeReader.isNotAWord()) { @@ -123,6 +224,9 @@ int DynamicPatriciaTriePolicy::getUnigramProbability( int DynamicPatriciaTriePolicy::getShortcutPositionOfNode( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) const { + if (nodePos == NOT_A_VALID_WORD_POS) { + return NOT_A_DICT_POS; + } DynamicPatriciaTrieNodeReader nodeReader(binaryDictionaryInfo); nodeReader.fetchNodeInfoFromBuffer(nodePos); if (nodeReader.isDeleted()) { @@ -134,6 +238,9 @@ int DynamicPatriciaTriePolicy::getBigramsPositionOfNode( const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) const { + if (nodePos == NOT_A_VALID_WORD_POS) { + return NOT_A_DICT_POS; + } DynamicPatriciaTrieNodeReader nodeReader(binaryDictionaryInfo); nodeReader.fetchNodeInfoFromBuffer(nodePos); if (nodeReader.isDeleted()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h index 39dfb86fd..6a7977138 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h @@ -61,6 +61,7 @@ class DynamicPatriciaTriePolicy : public DictionaryStructurePolicy { private: DISALLOW_COPY_AND_ASSIGN(DynamicPatriciaTriePolicy); static const 
DynamicPatriciaTriePolicy sInstance; + static const int MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP; DynamicPatriciaTriePolicy() {} ~DynamicPatriciaTriePolicy() {}