From 07b3b41c25e000615396399e484a041df9301449 Mon Sep 17 00:00:00 2001 From: Keisuke Kuroyanagi Date: Tue, 26 Aug 2014 12:01:08 +0900 Subject: [PATCH] Add a method to iterate entries in LanguageModelDictContent. Bug: 14425059 Change-Id: I4e9c3a97891c020f762fa709f806d333c067f496 --- .../content/language_model_dict_content.cpp | 6 ++ .../v4/content/language_model_dict_content.h | 71 +++++++++++++++++++ .../policyimpl/dictionary/utils/trie_map.h | 2 +- .../language_model_dict_content_test.cpp | 20 ++++++ 4 files changed, 98 insertions(+), 1 deletion(-) diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index ea2d24e67..eb2b1ec3e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -71,6 +71,12 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView return mTrieMap.remove(wordId, bitmapEntryIndex); } +LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries( + const WordIdArrayView prevWordIds) const { + const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); + return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo); +} + bool LanguageModelDictContent::truncateEntries(const int *const entryCounts, const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) { for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 43b2aab66..961637679 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -39,6 +39,75 @@ class HeaderPolicy; */ class LanguageModelDictContent { public: + // Pair of word id and probability entry used for iteration. + class WordIdAndProbabilityEntry { + public: + WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry) + : mWordId(wordId), mProbabilityEntry(probabilityEntry) {} + + int getWordId() const { return mWordId; } + const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry); + DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry); + + const int mWordId; + const ProbabilityEntry mProbabilityEntry; + }; + + // Iterator. + class EntryIterator { + public: + EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator, + const bool hasHistoricalInfo) + : mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {} + + const WordIdAndProbabilityEntry operator*() const { + const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator; + return WordIdAndProbabilityEntry( + result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo)); + } + + bool operator!=(const EntryIterator &other) const { + return mTrieMapIterator != other.mTrieMapIterator; + } + + const EntryIterator &operator++() { + ++mTrieMapIterator; + return *this; + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator); + DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator); + + TrieMap::TrieMapIterator mTrieMapIterator; + const bool mHasHistoricalInfo; + }; + + // Class represents range to use range base for loops. + class EntryRange { + public: + EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo) + : mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {} + + EntryIterator begin() const { + return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo); + } + + EntryIterator end() const { + return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo); + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange); + DISALLOW_ASSIGNMENT_OPERATOR(EntryRange); + + const TrieMap::TrieMapRange mTrieMapRange; + const bool mHasHistoricalInfo; + }; + LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer, const bool hasHistoricalInfo) : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {} @@ -76,6 +145,8 @@ class LanguageModelDictContent { bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId); + EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const; + bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h index c2aeac211..00765888b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h @@ -98,7 +98,7 @@ class TrieMap { TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex) : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex), mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) { - if (!trieMap) { + if (!trieMap || mBaseBitmapEntryIndex == INVALID_INDEX) { return; } const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex); diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp index 3cacba1c3..ca8d56f27 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp @@ -18,6 +18,8 @@ #include +#include + #include "utils/int_array_view.h" namespace latinime { @@ -69,5 +71,23 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) { EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId)); } +TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) { + LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */); + + const ProbabilityEntry originalEntry(0xFC, 100); + + const int wordIds[] = { 1, 2, 3, 4, 5 }; + for (const int wordId : wordIds) { + languageModelDictContent.setProbabilityEntry(wordId, &originalEntry); + } + std::unordered_set wordIdSet(std::begin(wordIds), std::end(wordIds)); + for (const auto entry : languageModelDictContent.getProbabilityEntries(WordIdArrayView())) { + EXPECT_EQ(originalEntry.getFlags(), entry.getProbabilityEntry().getFlags()); + EXPECT_EQ(originalEntry.getProbability(), entry.getProbabilityEntry().getProbability()); + wordIdSet.erase(entry.getWordId()); + } + EXPECT_TRUE(wordIdSet.empty()); +} + } // namespace } // namespace latinime