Introduce WordAttributes to get word probability and flags.
Bug: 14425059 Change-Id: Iee11d038e0893d7ddd6c52447907f8c55fecb6a5
This commit is contained in:
parent
11a48f92a5
commit
2111e3abc9
10 changed files with 116 additions and 28 deletions
|
@ -72,10 +72,10 @@ namespace latinime {
|
||||||
if (dicNode->hasMultipleWords() && !dicNode->isValidMultipleWordSuggestion()) {
|
if (dicNode->hasMultipleWords() && !dicNode->isValidMultipleWordSuggestion()) {
|
||||||
return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
|
return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
|
||||||
}
|
}
|
||||||
const int probability = dictionaryStructurePolicy->getProbabilityOfWordInContext(
|
const WordAttributes wordAttributes = dictionaryStructurePolicy->getWordAttributesInContext(
|
||||||
dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap);
|
dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap);
|
||||||
// TODO: This equation to calculate the improbability looks unreasonable. Investigate this.
|
// TODO: This equation to calculate the improbability looks unreasonable. Investigate this.
|
||||||
const float cost = static_cast<float>(MAX_PROBABILITY - probability)
|
const float cost = static_cast<float>(MAX_PROBABILITY - wordAttributes.getProbability())
|
||||||
/ static_cast<float>(MAX_PROBABILITY);
|
/ static_cast<float>(MAX_PROBABILITY);
|
||||||
return cost;
|
return cost;
|
||||||
}
|
}
|
||||||
|
|
|
@ -84,9 +84,10 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
|
||||||
if (codePointCount <= 0) {
|
if (codePointCount <= 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const int probability = mDictStructurePolicy->getProbabilityOfWordInContext(mPrevWordIds.data(),
|
const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
|
||||||
targetWordId, nullptr /* multiBigramMap */);
|
mPrevWordIds.data(), targetWordId, nullptr /* multiBigramMap */);
|
||||||
mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, probability);
|
mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
|
||||||
|
wordAttributes.getProbability());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
|
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
|
||||||
|
|
60
native/jni/src/suggest/core/dictionary/word_attributes.h
Normal file
60
native/jni/src/suggest/core/dictionary/word_attributes.h
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
/*
|
||||||
|
* Copyright (C) 2014, The Android Open Source Project
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef LATINIME_WORD_ATTRIBUTES_H
|
||||||
|
#define LATINIME_WORD_ATTRIBUTES_H
|
||||||
|
|
||||||
|
#include "defines.h"
|
||||||
|
|
||||||
|
class WordAttributes {
|
||||||
|
public:
|
||||||
|
// Invalid word attributes.
|
||||||
|
WordAttributes()
|
||||||
|
: mProbability(NOT_A_PROBABILITY), mIsBlacklisted(false), mIsNotAWord(false),
|
||||||
|
mIsPossiblyOffensive(false) {}
|
||||||
|
|
||||||
|
WordAttributes(const int probability, const bool isBlacklisted, const bool isNotAWord,
|
||||||
|
const bool isPossiblyOffensive)
|
||||||
|
: mProbability(probability), mIsBlacklisted(isBlacklisted), mIsNotAWord(isNotAWord),
|
||||||
|
mIsPossiblyOffensive(isPossiblyOffensive) {}
|
||||||
|
|
||||||
|
int getProbability() const {
|
||||||
|
return mProbability;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isBlacklisted() const {
|
||||||
|
return mIsBlacklisted;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isNotAWord() const {
|
||||||
|
return mIsNotAWord;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isPossiblyOffensive() const {
|
||||||
|
return mIsPossiblyOffensive;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DISALLOW_ASSIGNMENT_OPERATOR(WordAttributes);
|
||||||
|
|
||||||
|
int mProbability;
|
||||||
|
bool mIsBlacklisted;
|
||||||
|
bool mIsNotAWord;
|
||||||
|
bool mIsPossiblyOffensive;
|
||||||
|
};
|
||||||
|
|
||||||
|
// namespace
|
||||||
|
#endif /* LATINIME_WORD_ATTRIBUTES_H */
|
|
@ -22,6 +22,7 @@
|
||||||
#include "defines.h"
|
#include "defines.h"
|
||||||
#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h"
|
#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h"
|
||||||
#include "suggest/core/dictionary/property/word_property.h"
|
#include "suggest/core/dictionary/property/word_property.h"
|
||||||
|
#include "suggest/core/dictionary/word_attributes.h"
|
||||||
#include "utils/int_array_view.h"
|
#include "utils/int_array_view.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
@ -57,8 +58,8 @@ class DictionaryStructureWithBufferPolicy {
|
||||||
virtual int getWordId(const CodePointArrayView wordCodePoints,
|
virtual int getWordId(const CodePointArrayView wordCodePoints,
|
||||||
const bool forceLowerCaseSearch) const = 0;
|
const bool forceLowerCaseSearch) const = 0;
|
||||||
|
|
||||||
virtual int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId,
|
// Returns the attributes (probability and flags) of the word identified by wordId, looked up
// in the context of prevWordIds. multiBigramMap is optional; callers may pass nullptr.
virtual const WordAttributes getWordAttributesInContext(const int *const prevWordIds,
|
||||||
MultiBigramMap *const multiBigramMap) const = 0;
|
const int wordId, MultiBigramMap *const multiBigramMap) const = 0;
|
||||||
|
|
||||||
// TODO: Remove
|
// TODO: Remove
|
||||||
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
|
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
|
||||||
|
|
|
@ -118,24 +118,33 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
||||||
return getWordIdFromTerminalPtNodePos(ptNodePos);
|
return getWordIdFromTerminalPtNodePos(ptNodePos);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
|
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
||||||
const int wordId, MultiBigramMap *const multiBigramMap) const {
|
const int *const prevWordIds, const int wordId,
|
||||||
|
MultiBigramMap *const multiBigramMap) const {
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return NOT_A_PROBABILITY;
|
return WordAttributes();
|
||||||
}
|
}
|
||||||
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
|
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
|
||||||
const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
|
const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
|
||||||
if (multiBigramMap) {
|
if (multiBigramMap) {
|
||||||
return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
|
const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */,
|
||||||
wordId, ptNodeParams.getProbability());
|
prevWordIds, wordId, ptNodeParams.getProbability());
|
||||||
|
return getWordAttributes(probability, ptNodeParams);
|
||||||
}
|
}
|
||||||
if (prevWordIds) {
|
if (prevWordIds) {
|
||||||
const int probability = getProbabilityOfWord(prevWordIds, wordId);
|
const int probability = getProbabilityOfWord(prevWordIds, wordId);
|
||||||
if (probability != NOT_A_PROBABILITY) {
|
if (probability != NOT_A_PROBABILITY) {
|
||||||
return probability;
|
return getWordAttributes(probability, ptNodeParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY),
|
||||||
|
ptNodeParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Packs the given probability together with the flags read from ptNodeParams.
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probability,
        const PtNodeParams &ptNodeParams) const {
    // A stored PtNode probability of exactly 0 is reported as the possibly-offensive flag.
    const bool possiblyOffensive = (ptNodeParams.getProbability() == 0);
    return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
            possiblyOffensive);
}
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
||||||
|
|
|
@ -91,7 +91,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
|
|
||||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||||
|
|
||||||
int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId,
|
// Implements DictionaryStructureWithBufferPolicy::getWordAttributesInContext.
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
|
||||||
MultiBigramMap *const multiBigramMap) const;
|
MultiBigramMap *const multiBigramMap) const;
|
||||||
|
|
||||||
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
||||||
|
@ -166,6 +166,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
int getShortcutPositionOfPtNode(const int ptNodePos) const;
|
int getShortcutPositionOfPtNode(const int ptNodePos) const;
|
||||||
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
|
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
|
||||||
int getTerminalPtNodePosFromWordId(const int wordId) const;
|
int getTerminalPtNodePosFromWordId(const int wordId) const;
|
||||||
|
// Builds a WordAttributes from the given probability and the flags read from ptNodeParams.
const WordAttributes getWordAttributes(const int probability,
|
||||||
|
const PtNodeParams &ptNodeParams) const;
|
||||||
};
|
};
|
||||||
} // namespace v402
|
} // namespace v402
|
||||||
} // namespace backward
|
} // namespace backward
|
||||||
|
|
|
@ -282,25 +282,33 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
||||||
return getWordIdFromTerminalPtNodePos(ptNodePos);
|
return getWordIdFromTerminalPtNodePos(ptNodePos);
|
||||||
}
|
}
|
||||||
|
|
||||||
int PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
|
const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *const prevWordIds,
|
||||||
const int wordId, MultiBigramMap *const multiBigramMap) const {
|
const int wordId, MultiBigramMap *const multiBigramMap) const {
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return NOT_A_PROBABILITY;
|
return WordAttributes();
|
||||||
}
|
}
|
||||||
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
|
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
|
||||||
const PtNodeParams ptNodeParams =
|
const PtNodeParams ptNodeParams =
|
||||||
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
||||||
if (multiBigramMap) {
|
if (multiBigramMap) {
|
||||||
return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
|
const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */,
|
||||||
wordId, ptNodeParams.getProbability());
|
prevWordIds, wordId, ptNodeParams.getProbability());
|
||||||
|
return getWordAttributes(probability, ptNodeParams);
|
||||||
}
|
}
|
||||||
if (prevWordIds) {
|
if (prevWordIds) {
|
||||||
const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId);
|
const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId);
|
||||||
if (bigramProbability != NOT_A_PROBABILITY) {
|
if (bigramProbability != NOT_A_PROBABILITY) {
|
||||||
return bigramProbability;
|
return getWordAttributes(bigramProbability, ptNodeParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
|
return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY),
|
||||||
|
ptNodeParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Packs the given probability together with the flags read from ptNodeParams.
const WordAttributes PatriciaTriePolicy::getWordAttributes(const int probability,
        const PtNodeParams &ptNodeParams) const {
    // A stored PtNode probability of exactly 0 is reported as the possibly-offensive flag.
    const bool possiblyOffensive = (ptNodeParams.getProbability() == 0);
    return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
            possiblyOffensive);
}
|
||||||
|
|
||||||
int PatriciaTriePolicy::getProbability(const int unigramProbability,
|
int PatriciaTriePolicy::getProbability(const int unigramProbability,
|
||||||
|
|
|
@ -66,7 +66,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
|
|
||||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||||
|
|
||||||
int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId,
|
// Implements DictionaryStructureWithBufferPolicy::getWordAttributesInContext.
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
|
||||||
MultiBigramMap *const multiBigramMap) const;
|
MultiBigramMap *const multiBigramMap) const;
|
||||||
|
|
||||||
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
||||||
|
@ -163,6 +163,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
DicNodeVector *const childDicNodes) const;
|
DicNodeVector *const childDicNodes) const;
|
||||||
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
|
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
|
||||||
int getTerminalPtNodePosFromWordId(const int wordId) const;
|
int getTerminalPtNodePosFromWordId(const int wordId) const;
|
||||||
|
// Builds a WordAttributes from the given probability and the flags read from ptNodeParams.
const WordAttributes getWordAttributes(const int probability,
|
||||||
|
const PtNodeParams &ptNodeParams) const;
|
||||||
};
|
};
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
#endif // LATINIME_PATRICIA_TRIE_POLICY_H
|
#endif // LATINIME_PATRICIA_TRIE_POLICY_H
|
||||||
|
|
|
@ -113,14 +113,19 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
|
||||||
return ptNodeParams.getTerminalId();
|
return ptNodeParams.getTerminalId();
|
||||||
}
|
}
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
|
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
|
||||||
const int wordId, MultiBigramMap *const multiBigramMap) const {
|
const int *const prevWordIds, const int wordId,
|
||||||
|
MultiBigramMap *const multiBigramMap) const {
|
||||||
if (wordId == NOT_A_WORD_ID) {
|
if (wordId == NOT_A_WORD_ID) {
|
||||||
return NOT_A_PROBABILITY;
|
return WordAttributes();
|
||||||
}
|
}
|
||||||
|
const int ptNodePos =
|
||||||
|
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
|
||||||
|
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
|
||||||
// TODO: Support n-gram.
|
// TODO: Support n-gram.
|
||||||
return mBuffers->getLanguageModelDictContent()->getWordProbability(
|
return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
|
||||||
WordIdArrayView::singleElementView(prevWordIds), wordId);
|
WordIdArrayView::singleElementView(prevWordIds), wordId), ptNodeParams.isBlacklisted(),
|
||||||
|
ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
|
||||||
|
|
|
@ -68,7 +68,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
|
||||||
|
|
||||||
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
|
||||||
|
|
||||||
int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId,
|
// Implements DictionaryStructureWithBufferPolicy::getWordAttributesInContext.
const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
|
||||||
MultiBigramMap *const multiBigramMap) const;
|
MultiBigramMap *const multiBigramMap) const;
|
||||||
|
|
||||||
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
int getProbability(const int unigramProbability, const int bigramProbability) const;
|
||||||
|
|
Loading…
Reference in a new issue