Use WordIdArrayView for prevWordIds.

Bug: 14425059
Change-Id: Ia84fb997d89564e60111b46ca83bbfa3b187f316
main
Keisuke Kuroyanagi 2014-09-11 19:36:22 +09:00
parent a3b0eb1685
commit 537f6eea8a
20 changed files with 133 additions and 138 deletions

View File

@ -105,7 +105,7 @@ class DicNode {
}
// Init for root with prevWordIds which is used for n-gram
void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordIds) {
void initAsRoot(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) {
mIsCachedForNextSuggestion = false;
mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds);
mDicNodeState.init();
@ -115,12 +115,11 @@ class DicNode {
// Init for root with previous word
void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
int newPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> newPrevWordIds;
newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId();
for (size_t i = 1; i < NELEMS(newPrevWordIds); ++i) {
newPrevWordIds[i] = dicNode->getPrevWordIds()[i - 1];
}
mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordIds);
dicNode->getPrevWordIds().limit(newPrevWordIds.size() - 1)
.copyToArray(&newPrevWordIds, 1 /* offset */);
mDicNodeProperties.init(rootPtNodeArrayPos, WordIdArrayView::fromArray(newPrevWordIds));
mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
dicNode->mDicNodeProperties.getDepth());
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
@ -203,8 +202,7 @@ class DicNode {
return mDicNodeProperties.getWordId();
}
// TODO: Use view class to return word id array.
const int *getPrevWordIds() const {
const WordIdArrayView getPrevWordIds() const {
return mDicNodeProperties.getPrevWordIds();
}

View File

@ -28,7 +28,7 @@ namespace latinime {
/* static */ void DicNodeUtils::initAsRoot(
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
const int *const prevWordIds, DicNode *const newRootDicNode) {
const WordIdArrayView prevWordIds, DicNode *const newRootDicNode) {
newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds);
}

View File

@ -18,6 +18,7 @@
#define LATINIME_DIC_NODE_UTILS_H
#include "defines.h"
#include "utils/int_array_view.h"
namespace latinime {
@ -30,7 +31,7 @@ class DicNodeUtils {
public:
static void initAsRoot(
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
const int *const prevWordIds, DicNode *const newRootDicNode);
const WordIdArrayView prevWordIds, DicNode *const newRootDicNode);
static void initAsRootWithPreviousWord(
const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode);

View File

@ -18,8 +18,10 @@
#define LATINIME_DIC_NODE_PROPERTIES_H
#include <cstdint>
#include <cstdlib>
#include "defines.h"
#include "utils/int_array_view.h"
namespace latinime {
@ -36,23 +38,23 @@ class DicNodeProperties {
// Should be called only once per DicNode is initialized.
void init(const int childrenPos, const int nodeCodePoint, const int wordId,
const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordIds) {
const uint16_t depth, const uint16_t leavingDepth, const WordIdArrayView prevWordIds) {
mChildrenPtNodeArrayPos = childrenPos;
mDicNodeCodePoint = nodeCodePoint;
mWordId = wordId;
mDepth = depth;
mLeavingDepth = leavingDepth;
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */);
}
// Init for root with prevWordsPtNodePos which is used for n-gram
void init(const int rootPtNodeArrayPos, const int *const prevWordIds) {
void init(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) {
mChildrenPtNodeArrayPos = rootPtNodeArrayPos;
mDicNodeCodePoint = NOT_A_CODE_POINT;
mWordId = NOT_A_WORD_ID;
mDepth = 0;
mLeavingDepth = 0;
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */);
}
void initByCopy(const DicNodeProperties *const dicNodeProp) {
@ -61,7 +63,8 @@ class DicNodeProperties {
mWordId = dicNodeProp->mWordId;
mDepth = dicNodeProp->mDepth;
mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds)
.copyToArray(&mPrevWordIds, 0 /* offset */);
}
// Init as passing child
@ -71,7 +74,8 @@ class DicNodeProperties {
mWordId = dicNodeProp->mWordId;
mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child
mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds)
.copyToArray(&mPrevWordIds, 0 /* offset */);
}
int getChildrenPtNodeArrayPos() const {
@ -99,8 +103,8 @@ class DicNodeProperties {
return (mChildrenPtNodeArrayPos != NOT_A_DICT_POS) || mDepth != mLeavingDepth;
}
const int *getPrevWordIds() const {
return mPrevWordIds;
const WordIdArrayView getPrevWordIds() const {
return WordIdArrayView::fromArray(mPrevWordIds);
}
int getWordId() const {
@ -116,7 +120,7 @@ class DicNodeProperties {
int mWordId;
uint16_t mDepth;
uint16_t mLeavingDepth;
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIds;
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_PROPERTIES_H

View File

@ -85,7 +85,7 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
return;
}
const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
mPrevWordIds.data(), targetWordId, nullptr /* multiBigramMap */);
mPrevWordIds, targetWordId, nullptr /* multiBigramMap */);
mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
wordAttributes.getProbability());
}
@ -93,13 +93,13 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
SuggestionResults *const outSuggestionResults) const {
TimeKeeper::setCurrentTime();
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(),
true /* tryLowerCaseSearch */);
NgramListenerForPrediction listener(prevWordsInfo,
WordIdArrayView::fromFixedSizeArray(prevWordIds), outSuggestionResults,
const WordIdArrayView prevWordIdArrayView = WordIdArrayView::fromArray(prevWordIds);
NgramListenerForPrediction listener(prevWordsInfo, prevWordIdArrayView, outSuggestionResults,
mDictionaryStructureWithBufferPolicy.get());
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener);
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIdArrayView, &listener);
}
int Dictionary::getProbability(const int *word, int length) const {
@ -119,13 +119,13 @@ int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, co
CodePointArrayView(word, length), false /* forceLowerCaseSearch */);
if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY;
if (!prevWordsInfo) {
return getDictionaryStructurePolicy()->getProbabilityOfWord(
nullptr /* prevWordsPtNodePos */, wordId);
return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId);
}
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(),
true /* tryLowerCaseSearch */);
return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId);
return getDictionaryStructurePolicy()->getProbabilityOfWord(
IntArrayView::fromArray(prevWordIds), wordId);
}
bool Dictionary::addUnigramEntry(const int *const word, const int length,

View File

@ -23,6 +23,7 @@
#include "suggest/core/dictionary/digraph_utils.h"
#include "suggest/core/session/prev_words_info.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "utils/int_array_view.h"
namespace latinime {
@ -34,11 +35,12 @@ namespace latinime {
// No prev words information.
PrevWordsInfo emptyPrevWordsInfo;
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds,
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds.data(),
false /* tryLowerCaseSearch */);
current.emplace_back();
DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, &current.front());
DicNodeUtils::initAsRoot(dictionaryStructurePolicy,
IntArrayView::fromArray(prevWordIds), &current.front());
for (int i = 0; i < codePointCount; ++i) {
// The base-lower input is used to ignore case errors and accent errors.
const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]);

View File

@ -35,9 +35,9 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP =
// Also caches the bigrams if there is space remaining and they have not been cached already.
int MultiBigramMap::getBigramProbability(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordIds, const int nextWordId,
const WordIdArrayView prevWordIds, const int nextWordId,
const int unigramProbability) {
if (!prevWordIds || prevWordIds[0] == NOT_A_WORD_ID) {
if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) {
return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY);
}
const auto mapPosition = mBigramMaps.find(prevWordIds[0]);
@ -56,7 +56,7 @@ int MultiBigramMap::getBigramProbability(
void MultiBigramMap::BigramMap::init(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordIds) {
const WordIdArrayView prevWordIds) {
structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */);
}
@ -83,16 +83,13 @@ void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, const i
void MultiBigramMap::addBigramsForWord(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordIds) {
if (prevWordIds) {
mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds);
}
const WordIdArrayView prevWordIds) {
mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds);
}
int MultiBigramMap::readBigramProbabilityFromBinaryDictionary(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordIds, const int nextWordId,
const int unigramProbability) {
const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability) {
const int bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId);
if (bigramProbability != NOT_A_PROBABILITY) {
return bigramProbability;

View File

@ -25,6 +25,7 @@
#include "suggest/core/dictionary/bloom_filter.h"
#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "utils/int_array_view.h"
namespace latinime {
@ -39,7 +40,7 @@ class MultiBigramMap {
// Look up the bigram probability for the given word pair from the cached bigram maps.
// Also caches the bigrams if there is space remaining and they have not been cached already.
int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordIds, const int nextWordId, const int unigramProbability);
const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability);
void clear() {
mBigramMaps.clear();
@ -57,7 +58,7 @@ class MultiBigramMap {
virtual ~BigramMap() {}
void init(const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordIds);
const WordIdArrayView prevWordIds);
int getBigramProbability(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int nextWordId, const int unigramProbability) const;
@ -70,11 +71,11 @@ class MultiBigramMap {
};
void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordIds);
const WordIdArrayView prevWordIds);
int readBigramProbabilityFromBinaryDictionary(
const DictionaryStructureWithBufferPolicy *const structurePolicy,
const int *const prevWordIds, const int nextWordId, const int unigramProbability);
const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability);
static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP;
std::unordered_map<int, BigramMap> mBigramMaps;

View File

@ -58,15 +58,15 @@ class DictionaryStructureWithBufferPolicy {
virtual int getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const = 0;
virtual const WordAttributes getWordAttributesInContext(const int *const prevWordIds,
virtual const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
const int wordId, MultiBigramMap *const multiBigramMap) const = 0;
// TODO: Remove
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0;
virtual int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const = 0;
virtual void iterateNgramEntries(const int *const prevWordIds,
virtual void iterateNgramEntries(const WordIdArrayView prevWordIds,
NgramListener *const listener) const = 0;
virtual BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const = 0;

View File

@ -35,7 +35,7 @@ void DicTraverseSession::init(const Dictionary *const dictionary,
mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy()
->getMultiWordCostMultiplier();
mSuggestOptions = suggestOptions;
prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds,
prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds.data(),
true /* tryLowerCaseSearch */);
}

View File

@ -24,6 +24,7 @@
#include "suggest/core/dicnode/dic_nodes_cache.h"
#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/layout/proximity_info_state.h"
#include "utils/int_array_view.h"
namespace latinime {
@ -55,9 +56,7 @@ class DicTraverseSession {
mMultiWordCostMultiplier(1.0f) {
// NOTE: mProximityInfoStates is an array of instances.
// No need to initialize it explicitly here.
for (size_t i = 0; i < NELEMS(mPrevWordsIds); ++i) {
mPrevWordsIds[i] = NOT_A_DICT_POS;
}
mPrevWordsIds.fill(NOT_A_DICT_POS);
}
// Non virtual inline destructor -- never inherit this class
@ -79,7 +78,7 @@ class DicTraverseSession {
//--------------------
const ProximityInfo *getProximityInfo() const { return mProximityInfo; }
const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; }
const int *getPrevWordIds() const { return mPrevWordsIds; }
const WordIdArrayView getPrevWordIds() const { return IntArrayView::fromArray(mPrevWordsIds); }
DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; }
const ProximityInfoState *getProximityInfoState(int id) const {
@ -166,7 +165,7 @@ class DicTraverseSession {
const int *const inputYs, const int *const times, const int *const pointerIds,
const int inputSize, const float maxSpatialDistance, const int maxPointerCount);
int mPrevWordsIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordsIds;
const ProximityInfo *mProximityInfo;
const Dictionary *mDictionary;
const SuggestOptions *mSuggestOptions;

View File

@ -116,7 +116,7 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
}
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
const int *const prevWordIds, const int wordId,
const WordIdArrayView prevWordIds, const int wordId,
MultiBigramMap *const multiBigramMap) const {
if (wordId == NOT_A_WORD_ID) {
return WordAttributes();
@ -128,7 +128,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
prevWordIds, wordId, ptNodeParams.getProbability());
return getWordAttributes(probability, ptNodeParams);
}
if (prevWordIds) {
if (!prevWordIds.empty()) {
const int probability = getProbabilityOfWord(prevWordIds, wordId);
if (probability != NOT_A_PROBABILITY) {
return getWordAttributes(probability, ptNodeParams);
@ -160,7 +160,7 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
}
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
const int wordId) const {
if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY;
@ -170,7 +170,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
if (prevWordIds) {
if (!prevWordIds.empty()) {
const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
@ -186,9 +186,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
NgramListener *const listener) const {
if (!prevWordIds) {
if (prevWordIds.empty()) {
return;
}
const int bigramsPosition = getBigramsPositionOfPtNode(

View File

@ -91,14 +91,15 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
MultiBigramMap *const multiBigramMap) const;
const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
const int wordId, MultiBigramMap *const multiBigramMap) const;
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const;
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
void iterateNgramEntries(const WordIdArrayView prevWordIds,
NgramListener *const listener) const;
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;

View File

@ -282,8 +282,9 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
return getWordIdFromTerminalPtNodePos(ptNodePos);
}
const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *const prevWordIds,
const int wordId, MultiBigramMap *const multiBigramMap) const {
const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(
const WordIdArrayView prevWordIds, const int wordId,
MultiBigramMap *const multiBigramMap) const {
if (wordId == NOT_A_WORD_ID) {
return WordAttributes();
}
@ -295,7 +296,7 @@ const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *c
prevWordIds, wordId, ptNodeParams.getProbability());
return getWordAttributes(probability, ptNodeParams);
}
if (prevWordIds) {
if (!prevWordIds.empty()) {
const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId);
if (bigramProbability != NOT_A_PROBABILITY) {
return getWordAttributes(bigramProbability, ptNodeParams);
@ -327,7 +328,8 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability,
}
}
int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const int wordId) const {
int PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
const int wordId) const {
if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY;
}
@ -340,7 +342,7 @@ int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const
// for shortcuts).
return NOT_A_PROBABILITY;
}
if (prevWordIds) {
if (!prevWordIds.empty()) {
const int bigramsPosition = getBigramsPositionOfPtNode(
getTerminalPtNodePosFromWordId(prevWordIds[0]));
BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition);
@ -356,9 +358,9 @@ int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
void PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
NgramListener *const listener) const {
if (!prevWordIds) {
if (prevWordIds.empty()) {
return;
}
const int bigramsPosition = getBigramsPositionOfPtNode(
@ -371,8 +373,7 @@ void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
}
}
BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator(
const int wordId) const {
BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator(const int wordId) const {
const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId));
return BinaryDictionaryShortcutIterator(&mShortcutListPolicy, shortcutPos);
}

View File

@ -66,14 +66,15 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
MultiBigramMap *const multiBigramMap) const;
const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
const int wordId, MultiBigramMap *const multiBigramMap) const;
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const;
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
void iterateNgramEntries(const WordIdArrayView prevWordIds,
NgramListener *const listener) const;
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;

View File

@ -16,6 +16,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h"
#include <array>
#include <vector>
#include "suggest/core/dicnode/dic_node.h"
@ -111,7 +112,7 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
}
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
const int *const prevWordIds, const int wordId,
const WordIdArrayView prevWordIds, const int wordId,
MultiBigramMap *const multiBigramMap) const {
if (wordId == NOT_A_WORD_ID) {
return WordAttributes();
@ -121,27 +122,11 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
// TODO: Support n-gram.
return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
WordIdArrayView::singleElementView(prevWordIds), wordId), ptNodeParams.isBlacklisted(),
prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(),
ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
}
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
const int bigramProbability) const {
if (mHeaderPolicy->isDecayingDict()) {
// Both probabilities are encoded. Decode them and get probability.
return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability);
} else {
if (unigramProbability == NOT_A_PROBABILITY) {
return NOT_A_PROBABILITY;
} else if (bigramProbability == NOT_A_PROBABILITY) {
return ProbabilityUtils::backoff(unigramProbability);
} else {
return bigramProbability;
}
}
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
const int wordId) const {
if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY;
@ -152,22 +137,19 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
if (prevWordIds) {
// TODO: Support n-gram.
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
IntArrayView::singleElementView(prevWordIds), wordId);
if (!probabilityEntry.isValid()) {
return NOT_A_PROBABILITY;
}
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
mHeaderPolicy);
} else {
return probabilityEntry.getProbability();
}
// TODO: Support n-gram.
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
prevWordIds.limit(1 /* maxSize */), wordId);
if (!probabilityEntry.isValid()) {
return NOT_A_PROBABILITY;
}
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
mHeaderPolicy);
} else {
return probabilityEntry.getProbability();
}
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
@ -176,15 +158,15 @@ BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
return BinaryDictionaryShortcutIterator(&mShortcutPolicy, shortcutPos);
}
void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds,
NgramListener *const listener) const {
if (!prevWordIds) {
if (prevWordIds.empty()) {
return;
}
// TODO: Support n-gram.
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
for (const auto entry : languageModelDictContent->getProbabilityEntries(
WordIdArrayView::singleElementView(prevWordIds))) {
prevWordIds.limit(1 /* maxSize */))) {
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
@ -321,8 +303,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
"length: %zd", bigramProperty->getTargetCodePoints()->size());
return false;
}
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */);
// TODO: Support N-gram.
if (prevWordIds[0] == NOT_A_WORD_ID) {
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
@ -337,7 +319,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
return false;
}
// Refresh word ids.
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */);
} else {
return false;
}
@ -348,14 +330,14 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
return false;
}
bool addedNewEntry = false;
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos;
for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) {
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(prevWordIds[i]);
}
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(wordId);
if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos),
wordPtNodePos, bigramProperty, &addedNewEntry)) {
if (addedNewEntry) {
mBigramCount++;
@ -385,8 +367,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
wordCodePoints.size());
}
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */);
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSerch */);
// TODO: Support N-gram.
if (prevWordIds[0] == NOT_A_WORD_ID) {
return false;
@ -395,14 +377,14 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
if (wordId == NOT_A_WORD_ID) {
return false;
}
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
for (size_t i = 0; i < NELEMS(prevWordIds); ++i) {
std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos;
for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) {
prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(prevWordIds[i]);
}
const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable()
->getTerminalPtNodePosition(wordId);
if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos),
if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos),
wordPtNodePos)) {
mBigramCount--;
return true;

View File

@ -68,14 +68,19 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
MultiBigramMap *const multiBigramMap) const;
const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds,
const int wordId, MultiBigramMap *const multiBigramMap) const;
int getProbability(const int unigramProbability, const int bigramProbability) const;
// TODO: Remove
int getProbability(const int unigramProbability, const int bigramProbability) const {
// Not used.
return NOT_A_PROBABILITY;
}
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const;
void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const;
void iterateNgramEntries(const WordIdArrayView prevWordIds,
NgramListener *const listener) const;
BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const;

View File

@ -57,9 +57,9 @@ class IntArrayView {
explicit IntArrayView(const std::vector<int> &vector)
: mPtr(vector.data()), mSize(vector.size()) {}
template <int N>
AK_FORCE_INLINE static IntArrayView fromFixedSizeArray(const int (&array)[N]) {
return IntArrayView(array, N);
template <size_t N>
AK_FORCE_INLINE static IntArrayView fromArray(const std::array<int, N> &array) {
return IntArrayView(array.data(), array.size());
}
// Returns a view that points one int object.
@ -120,6 +120,8 @@ class IntArrayView {
using WordIdArrayView = IntArrayView;
using PtNodePosArrayView = IntArrayView;
using CodePointArrayView = IntArrayView;
template <size_t size>
using WordIdArray = std::array<int, size>;
} // namespace latinime
#endif // LATINIME_MEMORY_VIEW_H

View File

@ -18,6 +18,7 @@
#include <gtest/gtest.h>
#include <array>
#include <unordered_set>
#include "utils/int_array_view.h"
@ -97,8 +98,8 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
const int bigramProbability = 20;
const int trigramProbability = 30;
const int wordId = 100;
const int prevWordIdArray[] = { 1, 2 };
const WordIdArrayView prevWordIds = WordIdArrayView::fromFixedSizeArray(prevWordIdArray);
const std::array<int, 2> prevWordIdArray = {{ 1, 2 }};
const WordIdArrayView prevWordIds = WordIdArrayView::fromArray(prevWordIdArray);
const ProbabilityEntry probabilityEntry(flag, probability);
languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);

View File

@ -46,8 +46,8 @@ TEST(IntArrayViewTest, TestIteration) {
TEST(IntArrayViewTest, TestConstructFromArray) {
const size_t ARRAY_SIZE = 100;
int intArray[ARRAY_SIZE];
const auto intArrayView = IntArrayView::fromFixedSizeArray(intArray);
std::array<int, ARRAY_SIZE> intArray;
const auto intArrayView = IntArrayView::fromArray(intArray);
EXPECT_EQ(ARRAY_SIZE, intArrayView.size());
}