Add methods for unigrams to LanguageModelDictContent.

Bug: 14425059
Change-Id: I0a6b480a3d4735787ffac68c47b4ffefc3f1b8a5
This commit is contained in:
Keisuke Kuroyanagi 2014-08-05 12:38:55 +09:00
parent 85b7b967b7
commit 0889484266
8 changed files with 277 additions and 2 deletions

View file

@ -126,6 +126,8 @@ LATIN_IME_CORE_TEST_FILES := \
defines_test.cpp \
suggest/core/layout/normal_distribution_2d_test.cpp \
suggest/core/dictionary/bloom_filter_test.cpp \
suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp \
suggest/policyimpl/dictionary/structure/v4/content/probability_entry_test.cpp \
suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer_test.cpp \
suggest/policyimpl/dictionary/utils/trie_map_test.cpp \
utils/autocorrection_threshold_utils_test.cpp \

View file

@ -22,4 +22,63 @@ bool LanguageModelDictContent::save(FILE *const file) const {
return mTrieMap.save(file);
}
bool LanguageModelDictContent::runGC(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const LanguageModelDictContent *const originalContent,
int *const outNgramCount) {
return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
0 /* nextLevelBitmapEntryIndex */, outNgramCount);
}
ProbabilityEntry LanguageModelDictContent::getProbabilityEntry(
const WordIdArrayView prevWordIds, const int wordId) const {
if (!prevWordIds.empty()) {
// TODO: Read n-gram entry.
return ProbabilityEntry();
}
const TrieMap::Result result = mTrieMap.getRoot(wordId);
if (!result.mIsValid) {
// Not found.
return ProbabilityEntry();
}
return ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
}
bool LanguageModelDictContent::setProbabilityEntry(const WordIdArrayView prevWordIds,
const int terminalId, const ProbabilityEntry *const probabilityEntry) {
if (!prevWordIds.empty()) {
// TODO: Add n-gram entry.
return false;
}
return mTrieMap.putRoot(terminalId, probabilityEntry->encode(mHasHistoricalInfo));
}
bool LanguageModelDictContent::runGCInner(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange,
const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
for (auto &entry : trieMapRange) {
const auto it = terminalIdMap->find(entry.key());
if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
// The word has been removed.
continue;
}
if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
return false;
}
if (outNgramCount) {
*outNgramCount += 1;
}
if (entry.hasNextLevelMap()) {
if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex),
outNgramCount)) {
return false;
}
}
}
return true;
}
} // namespace latinime

View file

@ -20,25 +20,53 @@
#include <cstdio>
#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/trie_map.h"
#include "utils/byte_array_view.h"
#include "utils/int_array_view.h"
namespace latinime {
/**
* Class representing language model.
*
* This class provides methods to get and store unigram/n-gram probability information and flags.
*/
class LanguageModelDictContent {
public:
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
const bool hasHistoricalInfo)
: mTrieMap(trieMapBuffer) {}
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
explicit LanguageModelDictContent(const bool hasHistoricalInfo) : mTrieMap() {}
explicit LanguageModelDictContent(const bool hasHistoricalInfo)
: mTrieMap(), mHasHistoricalInfo(hasHistoricalInfo) {}
bool isNearSizeLimit() const {
return mTrieMap.isNearSizeLimit();
}
bool save(FILE *const file) const;
bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const LanguageModelDictContent *const originalContent,
int *const outNgramCount);
ProbabilityEntry getProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId) const;
bool setProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId,
const ProbabilityEntry *const probabilityEntry);
private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
TrieMap mTrieMap;
const bool mHasHistoricalInfo;
bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
int *const outNgramCount);
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */

View file

@ -17,6 +17,9 @@
#ifndef LATINIME_PROBABILITY_ENTRY_H
#define LATINIME_PROBABILITY_ENTRY_H
#include <climits>
#include <cstdint>
#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/historical_info.h"
@ -67,6 +70,50 @@ class ProbabilityEntry {
return &mHistoricalInfo;
}
uint64_t encode(const bool hasHistoricalInfo) const {
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
if (hasHistoricalInfo) {
encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mHistoricalInfo.getTimeStamp());
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mHistoricalInfo.getLevel());
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mHistoricalInfo.getCount());
} else {
encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
^ static_cast<uint64_t>(mProbability);
}
return encodedEntry;
}
static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
if (hasHistoricalInfo) {
const int flags = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
Ver4DictConstants::TIME_STAMP_FIELD_SIZE
+ Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
+ Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
const int timestamp = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::TIME_STAMP_FIELD_SIZE,
Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
+ Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
const int level = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::WORD_LEVEL_FIELD_SIZE,
Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
const int count = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */);
const HistoricalInfo historicalInfo(timestamp, level, count);
return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
} else {
const int flags = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
Ver4DictConstants::PROBABILITY_SIZE);
const int probability = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);
return ProbabilityEntry(flags, probability);
}
}
private:
// Copy constructor is public to use this class as a type of return value.
DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry);
@ -74,6 +121,11 @@ class ProbabilityEntry {
const int mFlags;
const int mProbability;
const HistoricalInfo mHistoricalInfo;
static int readFromEncodedEntry(const uint64_t encodedEntry, const int size, const int pos) {
return static_cast<int>(
(encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1));
}
};
} // namespace latinime
#endif /* LATINIME_PROBABILITY_ENTRY_H */

View file

@ -90,6 +90,14 @@ class Ver4DictBuffers {
return &mProbabilityDictContent;
}
AK_FORCE_INLINE LanguageModelDictContent *getMutableLanguageModelDictContent() {
return &mLanguageModelDictContent;
}
AK_FORCE_INLINE const LanguageModelDictContent *getLanguageModelDictContent() const {
return &mLanguageModelDictContent;
}
AK_FORCE_INLINE BigramDictContent *getMutableBigramDictContent() {
return &mBigramDictContent;
}

View file

@ -61,6 +61,10 @@ class IntArrayView {
return mPtr[index];
}
AK_FORCE_INLINE bool empty() const {
return size() == 0;
}
AK_FORCE_INLINE size_t size() const {
return mSize;
}
@ -76,5 +80,7 @@ class IntArrayView {
const size_t mSize;
};
using WordIdArrayView = IntArrayView;
} // namespace latinime
#endif // LATINIME_MEMORY_VIEW_H

View file

@ -0,0 +1,60 @@
/*
* Copyright (C) 2014 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
#include <gtest/gtest.h>
#include "utils/int_array_view.h"
namespace latinime {
namespace {
TEST(LanguageModelDictContentTest, TestUnigramProbability) {
LanguageModelDictContent LanguageModelDictContent(false /* useHistoricalInfo */);
const int flag = 0xFF;
const int probability = 10;
const int wordId = 100;
const ProbabilityEntry probabilityEntry(flag, probability);
LanguageModelDictContent.setProbabilityEntry(WordIdArrayView(), wordId, &probabilityEntry);
const ProbabilityEntry entry =
LanguageModelDictContent.getProbabilityEntry(WordIdArrayView(), wordId);
EXPECT_EQ(flag, entry.getFlags());
EXPECT_EQ(probability, entry.getProbability());
}
TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
LanguageModelDictContent LanguageModelDictContent(true /* useHistoricalInfo */);
const int flag = 0xF0;
const int timestamp = 0x3FFFFFFF;
const int level = 3;
const int count = 10;
const int wordId = 100;
const HistoricalInfo historicalInfo(timestamp, level, count);
const ProbabilityEntry probabilityEntry(flag, NOT_A_PROBABILITY, &historicalInfo);
LanguageModelDictContent.setProbabilityEntry(WordIdArrayView(), wordId, &probabilityEntry);
const ProbabilityEntry entry =
LanguageModelDictContent.getProbabilityEntry(WordIdArrayView(), wordId);
EXPECT_EQ(flag, entry.getFlags());
EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp());
EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel());
EXPECT_EQ(count, entry.getHistoricalInfo()->getCount());
}
} // namespace
} // namespace latinime

View file

@ -0,0 +1,60 @@
/*
* Copyright (C) 2014 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
#include <gtest/gtest.h>
#include "defines.h"
namespace latinime {
namespace {
TEST(ProbabilityEntryTest, TestEncodeDecode) {
const int flag = 0xFF;
const int probability = 10;
const ProbabilityEntry entry(flag, probability);
const uint64_t encodedEntry = entry.encode(false /* hasHistoricalInfo */);
const ProbabilityEntry decodedEntry =
ProbabilityEntry::decode(encodedEntry, false /* hasHistoricalInfo */);
EXPECT_EQ(0xFF0Aull, encodedEntry);
EXPECT_EQ(flag, decodedEntry.getFlags());
EXPECT_EQ(probability, decodedEntry.getProbability());
}
TEST(ProbabilityEntryTest, TestEncodeDecodeWithHistoricalInfo) {
const int flag = 0xF0;
const int timestamp = 0x3FFFFFFF;
const int level = 3;
const int count = 10;
const HistoricalInfo historicalInfo(timestamp, level, count);
const ProbabilityEntry entry(flag, NOT_A_PROBABILITY, &historicalInfo);
const uint64_t encodedEntry = entry.encode(true /* hasHistoricalInfo */);
EXPECT_EQ(0xF03FFFFFFF030Aull, encodedEntry);
const ProbabilityEntry decodedEntry =
ProbabilityEntry::decode(encodedEntry, true /* hasHistoricalInfo */);
EXPECT_EQ(flag, decodedEntry.getFlags());
EXPECT_EQ(timestamp, decodedEntry.getHistoricalInfo()->getTimeStamp());
EXPECT_EQ(level, decodedEntry.getHistoricalInfo()->getLevel());
EXPECT_EQ(count, decodedEntry.getHistoricalInfo()->getCount());
}
} // namespace
} // namespace latinime