am 82f7d3a9: Merge "Add a method to iterate entries in LanguageModelDictContent."
* commit '82f7d3a9de4a4029286a6cbc9f890c236c6789cc': Add a method to iterate entries in LanguageModelDictContent.main
commit
c996cd4965
|
@ -71,6 +71,12 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
|
||||||
return mTrieMap.remove(wordId, bitmapEntryIndex);
|
return mTrieMap.remove(wordId, bitmapEntryIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries(
|
||||||
|
const WordIdArrayView prevWordIds) const {
|
||||||
|
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
|
||||||
|
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
|
||||||
|
}
|
||||||
|
|
||||||
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
|
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
|
||||||
const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
|
const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
|
||||||
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
||||||
|
|
|
@ -39,6 +39,75 @@ class HeaderPolicy;
|
||||||
*/
|
*/
|
||||||
class LanguageModelDictContent {
|
class LanguageModelDictContent {
|
||||||
public:
|
public:
|
||||||
|
// Pair of word id and probability entry used for iteration.
|
||||||
|
class WordIdAndProbabilityEntry {
|
||||||
|
public:
|
||||||
|
WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry)
|
||||||
|
: mWordId(wordId), mProbabilityEntry(probabilityEntry) {}
|
||||||
|
|
||||||
|
int getWordId() const { return mWordId; }
|
||||||
|
const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry);
|
||||||
|
DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry);
|
||||||
|
|
||||||
|
const int mWordId;
|
||||||
|
const ProbabilityEntry mProbabilityEntry;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Iterator.
|
||||||
|
class EntryIterator {
|
||||||
|
public:
|
||||||
|
EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator,
|
||||||
|
const bool hasHistoricalInfo)
|
||||||
|
: mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {}
|
||||||
|
|
||||||
|
const WordIdAndProbabilityEntry operator*() const {
|
||||||
|
const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator;
|
||||||
|
return WordIdAndProbabilityEntry(
|
||||||
|
result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator!=(const EntryIterator &other) const {
|
||||||
|
return mTrieMapIterator != other.mTrieMapIterator;
|
||||||
|
}
|
||||||
|
|
||||||
|
const EntryIterator &operator++() {
|
||||||
|
++mTrieMapIterator;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator);
|
||||||
|
DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator);
|
||||||
|
|
||||||
|
TrieMap::TrieMapIterator mTrieMapIterator;
|
||||||
|
const bool mHasHistoricalInfo;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class represents range to use range base for loops.
|
||||||
|
class EntryRange {
|
||||||
|
public:
|
||||||
|
EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo)
|
||||||
|
: mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {}
|
||||||
|
|
||||||
|
EntryIterator begin() const {
|
||||||
|
return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
EntryIterator end() const {
|
||||||
|
return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange);
|
||||||
|
DISALLOW_ASSIGNMENT_OPERATOR(EntryRange);
|
||||||
|
|
||||||
|
const TrieMap::TrieMapRange mTrieMapRange;
|
||||||
|
const bool mHasHistoricalInfo;
|
||||||
|
};
|
||||||
|
|
||||||
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
|
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
|
||||||
const bool hasHistoricalInfo)
|
const bool hasHistoricalInfo)
|
||||||
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
|
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
|
||||||
|
@ -76,6 +145,8 @@ class LanguageModelDictContent {
|
||||||
|
|
||||||
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
|
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
|
||||||
|
|
||||||
|
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
|
||||||
|
|
||||||
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
|
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
|
||||||
int *const outEntryCounts) {
|
int *const outEntryCounts) {
|
||||||
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
|
||||||
|
|
|
@ -98,7 +98,7 @@ class TrieMap {
|
||||||
TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex)
|
TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex)
|
||||||
: mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex),
|
: mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex),
|
||||||
mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) {
|
mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) {
|
||||||
if (!trieMap) {
|
if (!trieMap || mBaseBitmapEntryIndex == INVALID_INDEX) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex);
|
const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex);
|
||||||
|
|
|
@ -18,6 +18,8 @@
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "utils/int_array_view.h"
|
#include "utils/int_array_view.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
@ -69,5 +71,23 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
|
||||||
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
|
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
|
||||||
|
LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
|
||||||
|
|
||||||
|
const ProbabilityEntry originalEntry(0xFC, 100);
|
||||||
|
|
||||||
|
const int wordIds[] = { 1, 2, 3, 4, 5 };
|
||||||
|
for (const int wordId : wordIds) {
|
||||||
|
languageModelDictContent.setProbabilityEntry(wordId, &originalEntry);
|
||||||
|
}
|
||||||
|
std::unordered_set<int> wordIdSet(std::begin(wordIds), std::end(wordIds));
|
||||||
|
for (const auto entry : languageModelDictContent.getProbabilityEntries(WordIdArrayView())) {
|
||||||
|
EXPECT_EQ(originalEntry.getFlags(), entry.getProbabilityEntry().getFlags());
|
||||||
|
EXPECT_EQ(originalEntry.getProbability(), entry.getProbabilityEntry().getProbability());
|
||||||
|
wordIdSet.erase(entry.getWordId());
|
||||||
|
}
|
||||||
|
EXPECT_TRUE(wordIdSet.empty());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
|
|
Loading…
Reference in New Issue