diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp index a7d86f9ae..c70047638 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp @@ -98,6 +98,43 @@ bool TrieMap::put(const int key, const uint64_t value, const int bitmapEntryInde return putInternal(unsignedKey, value, getBitShuffledKey(unsignedKey), bitmapEntryIndex, readEntry(bitmapEntryIndex), 0 /* level */); } +/** + * Iterate next entry in a certain level. + * + * @param iterationState the iteration state that will be read and updated in this method. + * @param outKey the output key + * @return Result instance. mIsValid is false when all entries are iterated. + */ +const TrieMap::Result TrieMap::iterateNext(std::vector *const iterationState, + int *const outKey) const { + while (!iterationState->empty()) { + TableIterationState &state = iterationState->back(); + if (state.mTableSize <= state.mCurrentIndex) { + // Move to parent. + iterationState->pop_back(); + } else { + const int entryIndex = state.mTableIndex + state.mCurrentIndex; + state.mCurrentIndex += 1; + const Entry entry = readEntry(entryIndex); + if (entry.isBitmapEntry()) { + // Move to child. + iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex()); + } else { + if (outKey) { + *outKey = entry.getKey(); + } + if (!entry.hasTerminalLink()) { + return Result(entry.getValue(), true, INVALID_INDEX); + } + const int valueEntryIndex = entry.getValueEntryIndex(); + const Entry valueEntry = readEntry(valueEntryIndex); + return Result(valueEntry.getValueOfValueEntry(), true, valueEntryIndex + 1); + } + } + } + // Visited all entries. + return Result(0, false, INVALID_INDEX); +} /** * Shuffle bits of the key in the fixed order. diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h index 2a9051f98..b5bcc3bc8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h @@ -44,6 +44,117 @@ class TrieMap { mNextLevelBitmapEntryIndex(nextLevelBitmapEntryIndex) {} }; + /** + * Struct to record iteration state in a table. + */ + struct TableIterationState { + int mTableSize; + int mTableIndex; + int mCurrentIndex; + + TableIterationState(const int tableSize, const int tableIndex) + : mTableSize(tableSize), mTableIndex(tableIndex), mCurrentIndex(0) {} + }; + + class TrieMapRange; + class TrieMapIterator { + public: + class IterationResult { + public: + IterationResult(const TrieMap *const trieMap, const int key, const uint64_t value, + const int nextLeveBitmapEntryIndex) + : mTrieMap(trieMap), mKey(key), mValue(value), + mNextLevelBitmapEntryIndex(nextLeveBitmapEntryIndex) {} + + const TrieMapRange getEntriesInNextLevel() const { + return TrieMapRange(mTrieMap, mNextLevelBitmapEntryIndex); + } + + bool hasNextLevelMap() const { + return mNextLevelBitmapEntryIndex != INVALID_INDEX; + } + + AK_FORCE_INLINE int key() const { + return mKey; + } + + AK_FORCE_INLINE uint64_t value() const { + return mValue; + } + + private: + const TrieMap *const mTrieMap; + const int mKey; + const uint64_t mValue; + const int mNextLevelBitmapEntryIndex; + }; + + TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex) + : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex), + mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) { + if (!trieMap) { + return; + } + const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex); + mStateStack.emplace_back( + mTrieMap->popCount(bitmapEntry.getBitmap()), bitmapEntry.getTableIndex()); + this->operator++(); + } + + const IterationResult operator*() const { + return IterationResult(mTrieMap, mKey, mValue, mNextLevelBitmapEntryIndex); + } + + bool operator!=(const TrieMapIterator &other) const { + // Caveat: This works only for for loops. + return mIsValid || other.mIsValid; + } + + const TrieMapIterator &operator++() { + const Result result = mTrieMap->iterateNext(&mStateStack, &mKey); + mValue = result.mValue; + mIsValid = result.mIsValid; + mNextLevelBitmapEntryIndex = result.mNextLevelBitmapEntryIndex; + return *this; + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapIterator); + DISALLOW_ASSIGNMENT_OPERATOR(TrieMapIterator); + + const TrieMap *const mTrieMap; + std::vector mStateStack; + const int mBaseBitmapEntryIndex; + int mKey; + uint64_t mValue; + bool mIsValid; + int mNextLevelBitmapEntryIndex; + }; + + /** + * Class to support iterating entries in TrieMap by range base for loops. + */ + class TrieMapRange { + public: + TrieMapRange(const TrieMap *const trieMap, const int bitmapEntryIndex) + : mTrieMap(trieMap), mBaseBitmapEntryIndex(bitmapEntryIndex) {}; + + TrieMapIterator begin() const { + return TrieMapIterator(mTrieMap, mBaseBitmapEntryIndex); + } + + const TrieMapIterator end() const { + return TrieMapIterator(nullptr, INVALID_INDEX); + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapRange); + DISALLOW_ASSIGNMENT_OPERATOR(TrieMapRange); + + const TrieMap *const mTrieMap; + const int mBaseBitmapEntryIndex; + }; + static const int INVALID_INDEX; static const uint64_t MAX_VALUE; @@ -73,6 +184,14 @@ class TrieMap { bool put(const int key, const uint64_t value, const int bitmapEntryIndex); + const TrieMapRange getEntriesInRootLevel() const { + return getEntriesInSpecifiedLevel(ROOT_BITMAP_ENTRY_INDEX); + } + + const TrieMapRange getEntriesInSpecifiedLevel(const int bitmapEntryIndex) const { + return TrieMapRange(this, bitmapEntryIndex); + } + private: DISALLOW_COPY_AND_ASSIGN(TrieMap); @@ -171,6 +290,8 @@ class TrieMap { bool addNewEntryByExpandingTable(const uint32_t key, const uint64_t value, const int tableIndex, const uint32_t bitmap, const int bitmapEntryIndex, const int label); + const Result iterateNext(std::vector *const iterationState, + int *const outKey) const; AK_FORCE_INLINE const Entry readEntry(const int entryIndex) const { return Entry(readField0(entryIndex), readField1(entryIndex)); diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp index 5dd782277..df778b6cf 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp @@ -54,7 +54,7 @@ TEST(TrieMapTest, TestSetAndGetLarge) { EXPECT_TRUE(trieMap.putRoot(i, i)); } for (int i = 0; i < ELEMENT_COUNT; ++i) { - EXPECT_EQ(trieMap.getRoot(i).mValue, static_cast(i)); + EXPECT_EQ(static_cast(i), trieMap.getRoot(i).mValue); } } @@ -78,7 +78,7 @@ TEST(TrieMapTest, TestRandSetAndGetLarge) { testKeyValuePairs[key] = value; } for (const auto &v : testKeyValuePairs) { - EXPECT_EQ(trieMap.getRoot(v.first).mValue, v.second); + EXPECT_EQ(v.second, trieMap.getRoot(v.first).mValue); } } @@ -163,6 +163,61 @@ TEST(TrieMapTest, TestMultiLevel) { } } } + + // Iteration + for (const auto &firstLevelEntry : trieMap.getEntriesInRootLevel()) { + EXPECT_EQ(trieMap.getRoot(firstLevelEntry.key()).mValue, firstLevelEntry.value()); + EXPECT_EQ(firstLevelEntries[firstLevelEntry.key()], firstLevelEntry.value()); + firstLevelEntries.erase(firstLevelEntry.key()); + for (const auto &secondLevelEntry : firstLevelEntry.getEntriesInNextLevel()) { + EXPECT_EQ(twoLevelMap[firstLevelEntry.key()][secondLevelEntry.key()], + secondLevelEntry.value()); + twoLevelMap[firstLevelEntry.key()].erase(secondLevelEntry.key()); + for (const auto &thirdLevelEntry : secondLevelEntry.getEntriesInNextLevel()) { + EXPECT_EQ(threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()] + [thirdLevelEntry.key()], thirdLevelEntry.value()); + threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()].erase( + thirdLevelEntry.key()); + } + } + } + + // Ensure all entries have been traversed. + EXPECT_TRUE(firstLevelEntries.empty()); + for (const auto &secondLevelEntry : twoLevelMap) { + EXPECT_TRUE(secondLevelEntry.second.empty()); + } + for (const auto &secondLevelEntry : threeLevelMap) { + for (const auto &thirdLevelEntry : secondLevelEntry.second) { + EXPECT_TRUE(thirdLevelEntry.second.empty()); + } + } +} + +TEST(TrieMapTest, TestIteration) { + static const int ELEMENT_COUNT = 200000; + TrieMap trieMap; + std::unordered_map testKeyValuePairs; + + // Use the uniform integer distribution [S_INT_MIN, S_INT_MAX]. + std::uniform_int_distribution keyDistribution(S_INT_MIN, S_INT_MAX); + auto keyRandomNumberGenerator = std::bind(keyDistribution, std::mt19937()); + + // Use the uniform distribution [0, TrieMap::MAX_VALUE]. + std::uniform_int_distribution valueDistribution(0, TrieMap::MAX_VALUE); + auto valueRandomNumberGenerator = std::bind(valueDistribution, std::mt19937()); + for (int i = 0; i < ELEMENT_COUNT; ++i) { + const int key = keyRandomNumberGenerator(); + const uint64_t value = valueRandomNumberGenerator(); + EXPECT_TRUE(trieMap.putRoot(key, value)); + testKeyValuePairs[key] = value; + } + for (const auto &entry : trieMap.getEntriesInRootLevel()) { + EXPECT_EQ(trieMap.getRoot(entry.key()).mValue, entry.value()); + EXPECT_EQ(testKeyValuePairs[entry.key()], entry.value()); + testKeyValuePairs.erase(entry.key()); + } + EXPECT_TRUE(testKeyValuePairs.empty()); } } // namespace