Merge "Add a method to iterate entries in LanguageModelDictContent."
commit
82f7d3a9de
|
@ -71,6 +71,12 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
|
|||
return mTrieMap.remove(wordId, bitmapEntryIndex);
|
||||
}
|
||||
|
||||
LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries(
|
||||
const WordIdArrayView prevWordIds) const {
|
||||
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
|
||||
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
|
||||
}
|
||||
|
||||
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
|
||||
const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
|
||||
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
||||
|
|
|
@ -39,6 +39,75 @@ class HeaderPolicy;
|
|||
*/
|
||||
class LanguageModelDictContent {
|
||||
public:
|
||||
// Pair of word id and probability entry used for iteration.
|
||||
class WordIdAndProbabilityEntry {
|
||||
public:
|
||||
WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry)
|
||||
: mWordId(wordId), mProbabilityEntry(probabilityEntry) {}
|
||||
|
||||
int getWordId() const { return mWordId; }
|
||||
const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; }
|
||||
|
||||
private:
|
||||
DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry);
|
||||
DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry);
|
||||
|
||||
const int mWordId;
|
||||
const ProbabilityEntry mProbabilityEntry;
|
||||
};
|
||||
|
||||
// Iterator.
|
||||
class EntryIterator {
|
||||
public:
|
||||
EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator,
|
||||
const bool hasHistoricalInfo)
|
||||
: mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {}
|
||||
|
||||
const WordIdAndProbabilityEntry operator*() const {
|
||||
const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator;
|
||||
return WordIdAndProbabilityEntry(
|
||||
result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo));
|
||||
}
|
||||
|
||||
bool operator!=(const EntryIterator &other) const {
|
||||
return mTrieMapIterator != other.mTrieMapIterator;
|
||||
}
|
||||
|
||||
const EntryIterator &operator++() {
|
||||
++mTrieMapIterator;
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator);
|
||||
DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator);
|
||||
|
||||
TrieMap::TrieMapIterator mTrieMapIterator;
|
||||
const bool mHasHistoricalInfo;
|
||||
};
|
||||
|
||||
// Class represents range to use range base for loops.
|
||||
class EntryRange {
|
||||
public:
|
||||
EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo)
|
||||
: mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {}
|
||||
|
||||
EntryIterator begin() const {
|
||||
return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo);
|
||||
}
|
||||
|
||||
EntryIterator end() const {
|
||||
return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo);
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange);
|
||||
DISALLOW_ASSIGNMENT_OPERATOR(EntryRange);
|
||||
|
||||
const TrieMap::TrieMapRange mTrieMapRange;
|
||||
const bool mHasHistoricalInfo;
|
||||
};
|
||||
|
||||
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
|
||||
const bool hasHistoricalInfo)
|
||||
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
|
||||
|
@ -76,6 +145,8 @@ class LanguageModelDictContent {
|
|||
|
||||
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
|
||||
|
||||
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
|
||||
|
||||
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
|
||||
int *const outEntryCounts) {
|
||||
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
||||
|
|
|
@ -98,7 +98,7 @@ class TrieMap {
|
|||
TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex)
|
||||
: mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex),
|
||||
mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) {
|
||||
if (!trieMap) {
|
||||
if (!trieMap || mBaseBitmapEntryIndex == INVALID_INDEX) {
|
||||
return;
|
||||
}
|
||||
const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex);
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "utils/int_array_view.h"
|
||||
|
||||
namespace latinime {
|
||||
|
@ -69,5 +71,23 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
|
|||
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
|
||||
}
|
||||
|
||||
TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
|
||||
LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
|
||||
|
||||
const ProbabilityEntry originalEntry(0xFC, 100);
|
||||
|
||||
const int wordIds[] = { 1, 2, 3, 4, 5 };
|
||||
for (const int wordId : wordIds) {
|
||||
languageModelDictContent.setProbabilityEntry(wordId, &originalEntry);
|
||||
}
|
||||
std::unordered_set<int> wordIdSet(std::begin(wordIds), std::end(wordIds));
|
||||
for (const auto entry : languageModelDictContent.getProbabilityEntries(WordIdArrayView())) {
|
||||
EXPECT_EQ(originalEntry.getFlags(), entry.getProbabilityEntry().getFlags());
|
||||
EXPECT_EQ(originalEntry.getProbability(), entry.getProbabilityEntry().getProbability());
|
||||
wordIdSet.erase(entry.getWordId());
|
||||
}
|
||||
EXPECT_TRUE(wordIdSet.empty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace latinime
|
||||
|
|
Loading…
Reference in New Issue