am 252412d7: Use additional multi-word cost per language (for Russian)
* commit '252412d7eb4573f91588b06b0fe49ef9f0ac38ac': Use additional multi-word cost per language (for Russian)
This commit is contained in:
commit
93a429a7b5
11 changed files with 69 additions and 65 deletions
|
@ -92,6 +92,7 @@ class BinaryFormat {
|
|||
const int unigramProbability, const int bigramProbability);
|
||||
static int getProbability(const int position, const std::map<int, int> *bigramMap,
|
||||
const uint8_t *bigramFilter, const int unigramProbability);
|
||||
static float getMultiWordCostMultiplier(const uint8_t *const dict);
|
||||
|
||||
// Flags for special processing
|
||||
// Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or
|
||||
|
@ -241,6 +242,17 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *
|
|||
return ((msb & 0x7F) << 8) | dict[(*pos)++];
|
||||
}
|
||||
|
||||
// Computes the multi-word cost multiplier from the dictionary header attribute
// "MULTIPLE_WORDS_DEMOTION_RATE", read as an int via readHeaderValueInt().
//   - header value == S_INT_MIN: returns 1.0f (presumably S_INT_MIN is the
//     "attribute not present" sentinel of readHeaderValueInt -- TODO confirm).
//   - header value <= 0: returns MAX_VALUE_FOR_WEIGHTING converted to float.
//   - otherwise: returns 100.0f / headerValue, so a header value of 100 yields 1.0f.
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) {
|
||||
const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE");
|
||||
// Sentinel case: keep the cost unchanged (multiplier of 1).
if (headerValue == S_INT_MIN) {
|
||||
return 1.0f;
|
||||
}
|
||||
// Non-positive values cannot be used as a divisor below; return the maximum weighting value.
if (headerValue <= 0) {
|
||||
return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
|
||||
}
|
||||
return 100.0f / static_cast<float>(headerValue);
|
||||
}
|
||||
|
||||
// Returns the byte stored at dict[*pos] and advances *pos by one (post-increment).
inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) {
|
||||
return dict[(*pos)++];
|
||||
}
|
||||
|
|
|
@ -424,10 +424,9 @@ typedef enum {
|
|||
CT_OMISSION,
|
||||
CT_INSERTION,
|
||||
CT_TRANSPOSITION,
|
||||
CT_SPACE_SUBSTITUTION,
|
||||
CT_SPACE_OMISSION,
|
||||
CT_COMPLETION,
|
||||
CT_TERMINAL,
|
||||
CT_NEW_WORD,
|
||||
CT_NEW_WORD_SPACE_OMITTION,
|
||||
CT_NEW_WORD_SPACE_SUBSTITUTION,
|
||||
} CorrectionType;
|
||||
#endif // LATINIME_DEFINES_H
|
||||
|
|
|
@ -38,7 +38,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
case CT_SUBSTITUTION:
|
||||
PROF_SUBSTITUTION(node->mProfiler);
|
||||
return;
|
||||
case CT_NEW_WORD:
|
||||
case CT_NEW_WORD_SPACE_OMITTION:
|
||||
PROF_NEW_WORD(node->mProfiler);
|
||||
return;
|
||||
case CT_MATCH:
|
||||
|
@ -50,7 +50,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
case CT_TERMINAL:
|
||||
PROF_TERMINAL(node->mProfiler);
|
||||
return;
|
||||
case CT_SPACE_SUBSTITUTION:
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
PROF_SPACE_SUBSTITUTION(node->mProfiler);
|
||||
return;
|
||||
case CT_INSERTION:
|
||||
|
@ -107,16 +107,16 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
case CT_SUBSTITUTION:
|
||||
// only used for typing
|
||||
return weighting->getSubstitutionCost();
|
||||
case CT_NEW_WORD:
|
||||
return weighting->getNewWordCost(dicNode);
|
||||
case CT_NEW_WORD_SPACE_OMITTION:
|
||||
return weighting->getNewWordCost(traverseSession, dicNode);
|
||||
case CT_MATCH:
|
||||
return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
|
||||
case CT_COMPLETION:
|
||||
return weighting->getCompletionCost(traverseSession, dicNode);
|
||||
case CT_TERMINAL:
|
||||
return weighting->getTerminalSpatialCost(traverseSession, dicNode);
|
||||
case CT_SPACE_SUBSTITUTION:
|
||||
return weighting->getSpaceSubstitutionCost();
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
|
||||
case CT_INSERTION:
|
||||
return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
|
||||
case CT_TRANSPOSITION:
|
||||
|
@ -135,7 +135,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
return 0.0f;
|
||||
case CT_SUBSTITUTION:
|
||||
return 0.0f;
|
||||
case CT_NEW_WORD:
|
||||
case CT_NEW_WORD_SPACE_OMITTION:
|
||||
return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap);
|
||||
case CT_MATCH:
|
||||
return 0.0f;
|
||||
|
@ -147,8 +147,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
traverseSession->getOffsetDict(), dicNode, bigramCacheMap);
|
||||
return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
|
||||
}
|
||||
case CT_SPACE_SUBSTITUTION:
|
||||
return 0.0f;
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap);
|
||||
case CT_INSERTION:
|
||||
return 0.0f;
|
||||
case CT_TRANSPOSITION:
|
||||
|
@ -168,7 +168,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
case CT_SUBSTITUTION:
|
||||
// Should return true?
|
||||
return false;
|
||||
case CT_NEW_WORD:
|
||||
case CT_NEW_WORD_SPACE_OMITTION:
|
||||
return false;
|
||||
case CT_MATCH:
|
||||
return false;
|
||||
|
@ -176,7 +176,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
return false;
|
||||
case CT_TERMINAL:
|
||||
return false;
|
||||
case CT_SPACE_SUBSTITUTION:
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
return false;
|
||||
case CT_INSERTION:
|
||||
return true;
|
||||
|
@ -197,7 +197,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
return false;
|
||||
case CT_SUBSTITUTION:
|
||||
return false;
|
||||
case CT_NEW_WORD:
|
||||
case CT_NEW_WORD_SPACE_OMITTION:
|
||||
return false;
|
||||
case CT_MATCH:
|
||||
return weighting->isProximityDicNode(traverseSession, dicNode);
|
||||
|
@ -205,7 +205,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
return false;
|
||||
case CT_TERMINAL:
|
||||
return false;
|
||||
case CT_SPACE_SUBSTITUTION:
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
return false;
|
||||
case CT_INSERTION:
|
||||
return false;
|
||||
|
@ -224,7 +224,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
return 0;
|
||||
case CT_SUBSTITUTION:
|
||||
return 0;
|
||||
case CT_NEW_WORD:
|
||||
case CT_NEW_WORD_SPACE_OMITTION:
|
||||
return 0;
|
||||
case CT_MATCH:
|
||||
return 1;
|
||||
|
@ -232,7 +232,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
|||
return 0;
|
||||
case CT_TERMINAL:
|
||||
return 0;
|
||||
case CT_SPACE_SUBSTITUTION:
|
||||
case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
||||
return 1;
|
||||
case CT_INSERTION:
|
||||
return 2;
|
||||
|
|
|
@ -56,7 +56,8 @@ class Weighting {
|
|||
const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
|
||||
|
||||
virtual float getNewWordCost(const DicNode *const dicNode) const = 0;
|
||||
virtual float getNewWordCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const = 0;
|
||||
|
||||
virtual float getNewWordBigramCost(
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
|
||||
|
@ -76,7 +77,8 @@ class Weighting {
|
|||
|
||||
virtual float getSubstitutionCost() const = 0;
|
||||
|
||||
virtual float getSpaceSubstitutionCost() const = 0;
|
||||
virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const = 0;
|
||||
|
||||
Weighting() {}
|
||||
virtual ~Weighting() {}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "suggest/core/session/dic_traverse_session.h"
|
||||
|
||||
#include "binary_format.h"
|
||||
#include "defines.h"
|
||||
#include "dictionary.h"
|
||||
#include "dic_traverse_wrapper.h"
|
||||
|
@ -63,6 +64,7 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer;
|
|||
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
|
||||
int prevWordLength) {
|
||||
mDictionary = dictionary;
|
||||
mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict());
|
||||
if (!prevWord) {
|
||||
mPrevWordPos = NOT_VALID_WORD;
|
||||
return;
|
||||
|
|
|
@ -36,7 +36,8 @@ class DicTraverseSession {
|
|||
AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr)
|
||||
: mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0),
|
||||
mDictionary(0), mDicNodesCache(), mBigramCacheMap(),
|
||||
mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1) {
|
||||
mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1),
|
||||
mMultiWordCostMultiplier(1.0f) {
|
||||
// NOTE: mProximityInfoStates is an array of instances.
|
||||
// No need to initialize it explicitly here.
|
||||
}
|
||||
|
@ -52,6 +53,7 @@ class DicTraverseSession {
|
|||
const int maxPointerCount);
|
||||
void resetCache(const int nextActiveCacheSize, const int maxWords);
|
||||
|
||||
// TODO: Remove
|
||||
const uint8_t *getOffsetDict() const;
|
||||
int getDictFlags() const;
|
||||
|
||||
|
@ -150,6 +152,10 @@ class DicTraverseSession {
|
|||
return mProximityInfoStates[0].touchPositionCorrectionEnabled();
|
||||
}
|
||||
|
||||
// Returns the multi-word cost multiplier currently stored in this session
// (the mMultiWordCostMultiplier member).
float getMultiWordCostMultiplier() const {
|
||||
return mMultiWordCostMultiplier;
|
||||
}
|
||||
|
||||
private:
|
||||
DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession);
|
||||
// threshold to start caching
|
||||
|
@ -170,6 +176,11 @@ class DicTraverseSession {
|
|||
int mInputSize;
|
||||
bool mPartiallyCommited;
|
||||
int mMaxPointerCount;
|
||||
|
||||
/////////////////////////////////
|
||||
// Configuration per dictionary
|
||||
float mMultiWordCostMultiplier;
|
||||
|
||||
};
|
||||
} // namespace latinime
|
||||
#endif // LATINIME_DIC_TRAVERSE_SESSION_H
|
||||
|
|
|
@ -33,16 +33,9 @@
|
|||
namespace latinime {
|
||||
|
||||
// Initialization of class constants.
|
||||
const int Suggest::LOOKAHEAD_DIC_NODES_CACHE_SIZE = 25;
|
||||
const int Suggest::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
|
||||
const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2;
|
||||
const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f;
|
||||
const float Suggest::AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD = 0.6f;
|
||||
|
||||
const bool Suggest::CORRECT_SPACE_OMISSION = true;
|
||||
const bool Suggest::CORRECT_TRANSPOSITION = true;
|
||||
const bool Suggest::CORRECT_INSERTION = true;
|
||||
const bool Suggest::CORRECT_OMISSION_G = true;
|
||||
|
||||
/**
|
||||
* Returns a set of suggestions for the given input touch points. The commitPoint argument indicates
|
||||
|
@ -270,12 +263,8 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
|
|||
// latest touch point yet. These are needed to apply look-ahead correction operations
|
||||
// that require special handling of the latest touch point. For example, with insertions
|
||||
// (e.g., "thiis" -> "this") the latest touch point should not be consumed at all.
|
||||
if (CORRECT_TRANSPOSITION) {
|
||||
processDicNodeAsTransposition(traverseSession, &dicNode);
|
||||
}
|
||||
if (CORRECT_INSERTION) {
|
||||
processDicNodeAsInsertion(traverseSession, &dicNode);
|
||||
}
|
||||
processDicNodeAsTransposition(traverseSession, &dicNode);
|
||||
processDicNodeAsInsertion(traverseSession, &dicNode);
|
||||
} else { // !isLookAheadCorrection
|
||||
// Only consider typing error corrections if the normalized compound distance is
|
||||
// below a spatial distance threshold.
|
||||
|
@ -531,13 +520,10 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode
|
|||
DicNode newDicNode;
|
||||
DicNodeUtils::initAsRootWithPreviousWord(traverseSession->getDicRootPos(),
|
||||
traverseSession->getOffsetDict(), dicNode, &newDicNode);
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_NEW_WORD, traverseSession, dicNode,
|
||||
const CorrectionType correctionType = spaceSubstitution ?
|
||||
CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION;
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode,
|
||||
&newDicNode, traverseSession->getBigramCacheMap());
|
||||
if (spaceSubstitution) {
|
||||
// Merge this with CT_NEW_WORD
|
||||
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SPACE_SUBSTITUTION,
|
||||
traverseSession, 0, &newDicNode, 0 /* bigramCacheMap */);
|
||||
}
|
||||
traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode);
|
||||
}
|
||||
} // namespace latinime
|
||||
|
|
|
@ -76,31 +76,16 @@ class Suggest : public SuggestInterface {
|
|||
void processDicNodeAsMatch(DicTraverseSession *traverseSession,
|
||||
DicNode *childDicNode) const;
|
||||
|
||||
// Dic nodes cache size for lookahead (autocompletion)
|
||||
static const int LOOKAHEAD_DIC_NODES_CACHE_SIZE;
|
||||
// Max characters to lookahead
|
||||
static const int MAX_LOOKAHEAD;
|
||||
// Inputs longer than this will autocorrect if the suggestion is multi-word
|
||||
static const int MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT;
|
||||
static const int MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE;
|
||||
// Base value for converting costs into scores (low so will not autocorrect without classifier)
|
||||
static const float BASE_OUTPUT_SCORE;
|
||||
|
||||
// Threshold for autocorrection classifier
|
||||
static const float AUTOCORRECT_CLASSIFICATION_THRESHOLD;
|
||||
// Threshold for computing the language model feature for autocorrect classification
|
||||
static const float AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD;
|
||||
|
||||
// Typing error correction settings
|
||||
static const bool CORRECT_SPACE_OMISSION;
|
||||
static const bool CORRECT_TRANSPOSITION;
|
||||
static const bool CORRECT_INSERTION;
|
||||
|
||||
const Traversal *const TRAVERSAL;
|
||||
const Scoring *const SCORING;
|
||||
const Weighting *const WEIGHTING;
|
||||
|
||||
static const bool CORRECT_OMISSION_G;
|
||||
};
|
||||
} // namespace latinime
|
||||
#endif // LATINIME_SUGGEST_IMPL_H
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
namespace latinime {
|
||||
const bool TypingTraversal::CORRECT_OMISSION = true;
|
||||
const bool TypingTraversal::CORRECT_SPACE_SUBSTITUTION = true;
|
||||
const bool TypingTraversal::CORRECT_SPACE_OMISSION = true;
|
||||
const bool TypingTraversal::CORRECT_NEW_WORD_SPACE_SUBSTITUTION = true;
|
||||
const bool TypingTraversal::CORRECT_NEW_WORD_SPACE_OMISSION = true;
|
||||
const TypingTraversal TypingTraversal::sInstance;
|
||||
} // namespace latinime
|
||||
|
|
|
@ -66,7 +66,7 @@ class TypingTraversal : public Traversal {
|
|||
|
||||
AK_FORCE_INLINE bool isSpaceSubstitutionTerminal(
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
|
||||
if (!CORRECT_SPACE_SUBSTITUTION) {
|
||||
if (!CORRECT_NEW_WORD_SPACE_SUBSTITUTION) {
|
||||
return false;
|
||||
}
|
||||
if (!canDoLookAheadCorrection(traverseSession, dicNode)) {
|
||||
|
@ -80,7 +80,7 @@ class TypingTraversal : public Traversal {
|
|||
|
||||
AK_FORCE_INLINE bool isSpaceOmissionTerminal(
|
||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
|
||||
if (!CORRECT_SPACE_OMISSION) {
|
||||
if (!CORRECT_NEW_WORD_SPACE_OMISSION) {
|
||||
return false;
|
||||
}
|
||||
const int inputSize = traverseSession->getInputSize();
|
||||
|
@ -173,8 +173,8 @@ class TypingTraversal : public Traversal {
|
|||
private:
|
||||
DISALLOW_COPY_AND_ASSIGN(TypingTraversal);
|
||||
static const bool CORRECT_OMISSION;
|
||||
static const bool CORRECT_SPACE_SUBSTITUTION;
|
||||
static const bool CORRECT_SPACE_OMISSION;
|
||||
static const bool CORRECT_NEW_WORD_SPACE_SUBSTITUTION;
|
||||
static const bool CORRECT_NEW_WORD_SPACE_OMISSION;
|
||||
static const TypingTraversal sInstance;
|
||||
|
||||
TypingTraversal() {}
|
||||
|
|
|
@ -128,10 +128,12 @@ class TypingWeighting : public Weighting {
|
|||
return cost + weightedDistance;
|
||||
}
|
||||
|
||||
float getNewWordCost(const DicNode *const dicNode) const {
|
||||
float getNewWordCost(const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const {
|
||||
const bool isCapitalized = dicNode->isCapitalized();
|
||||
return isCapitalized ?
|
||||
const float cost = isCapitalized ?
|
||||
ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD;
|
||||
return cost * traverseSession->getMultiWordCostMultiplier();
|
||||
}
|
||||
|
||||
float getNewWordBigramCost(
|
||||
|
@ -183,8 +185,13 @@ class TypingWeighting : public Weighting {
|
|||
return ScoringParams::SUBSTITUTION_COST;
|
||||
}
|
||||
|
||||
AK_FORCE_INLINE float getSpaceSubstitutionCost() const {
|
||||
return ScoringParams::SPACE_SUBSTITUTION_COST;
|
||||
AK_FORCE_INLINE float getSpaceSubstitutionCost(
|
||||
const DicTraverseSession *const traverseSession,
|
||||
const DicNode *const dicNode) const {
|
||||
const bool isCapitalized = dicNode->isCapitalized();
|
||||
const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + (isCapitalized ?
|
||||
ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD);
|
||||
return cost * traverseSession->getMultiWordCostMultiplier();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in a new issue