am 252412d7: Use additional multi-word cost per language (for Russian)

* commit '252412d7eb4573f91588b06b0fe49ef9f0ac38ac':
  Use additional multi-word cost per language (for Russian)
This commit is contained in:
Satoshi Kataoka 2013-04-16 00:49:56 -07:00 committed by Android Git Automerger
commit 93a429a7b5
11 changed files with 69 additions and 65 deletions

View file

@ -92,6 +92,7 @@ class BinaryFormat {
const int unigramProbability, const int bigramProbability); const int unigramProbability, const int bigramProbability);
static int getProbability(const int position, const std::map<int, int> *bigramMap, static int getProbability(const int position, const std::map<int, int> *bigramMap,
const uint8_t *bigramFilter, const int unigramProbability); const uint8_t *bigramFilter, const int unigramProbability);
static float getMultiWordCostMultiplier(const uint8_t *const dict);
// Flags for special processing // Flags for special processing
// Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or
@ -241,6 +242,17 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *
return ((msb & 0x7F) << 8) | dict[(*pos)++]; return ((msb & 0x7F) << 8) | dict[(*pos)++];
} }
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) {
const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE");
if (headerValue == S_INT_MIN) {
return 1.0f;
}
if (headerValue <= 0) {
return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
}
return 100.0f / static_cast<float>(headerValue);
}
inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) { inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) {
return dict[(*pos)++]; return dict[(*pos)++];
} }

View file

@ -424,10 +424,9 @@ typedef enum {
CT_OMISSION, CT_OMISSION,
CT_INSERTION, CT_INSERTION,
CT_TRANSPOSITION, CT_TRANSPOSITION,
CT_SPACE_SUBSTITUTION,
CT_SPACE_OMISSION,
CT_COMPLETION, CT_COMPLETION,
CT_TERMINAL, CT_TERMINAL,
CT_NEW_WORD, CT_NEW_WORD_SPACE_OMITTION,
CT_NEW_WORD_SPACE_SUBSTITUTION,
} CorrectionType; } CorrectionType;
#endif // LATINIME_DEFINES_H #endif // LATINIME_DEFINES_H

View file

@ -38,7 +38,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_SUBSTITUTION: case CT_SUBSTITUTION:
PROF_SUBSTITUTION(node->mProfiler); PROF_SUBSTITUTION(node->mProfiler);
return; return;
case CT_NEW_WORD: case CT_NEW_WORD_SPACE_OMITTION:
PROF_NEW_WORD(node->mProfiler); PROF_NEW_WORD(node->mProfiler);
return; return;
case CT_MATCH: case CT_MATCH:
@ -50,7 +50,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_TERMINAL: case CT_TERMINAL:
PROF_TERMINAL(node->mProfiler); PROF_TERMINAL(node->mProfiler);
return; return;
case CT_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
PROF_SPACE_SUBSTITUTION(node->mProfiler); PROF_SPACE_SUBSTITUTION(node->mProfiler);
return; return;
case CT_INSERTION: case CT_INSERTION:
@ -107,16 +107,16 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_SUBSTITUTION: case CT_SUBSTITUTION:
// only used for typing // only used for typing
return weighting->getSubstitutionCost(); return weighting->getSubstitutionCost();
case CT_NEW_WORD: case CT_NEW_WORD_SPACE_OMITTION:
return weighting->getNewWordCost(dicNode); return weighting->getNewWordCost(traverseSession, dicNode);
case CT_MATCH: case CT_MATCH:
return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
case CT_COMPLETION: case CT_COMPLETION:
return weighting->getCompletionCost(traverseSession, dicNode); return weighting->getCompletionCost(traverseSession, dicNode);
case CT_TERMINAL: case CT_TERMINAL:
return weighting->getTerminalSpatialCost(traverseSession, dicNode); return weighting->getTerminalSpatialCost(traverseSession, dicNode);
case CT_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
return weighting->getSpaceSubstitutionCost(); return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
case CT_INSERTION: case CT_INSERTION:
return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode); return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
case CT_TRANSPOSITION: case CT_TRANSPOSITION:
@ -135,7 +135,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return 0.0f; return 0.0f;
case CT_SUBSTITUTION: case CT_SUBSTITUTION:
return 0.0f; return 0.0f;
case CT_NEW_WORD: case CT_NEW_WORD_SPACE_OMITTION:
return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap);
case CT_MATCH: case CT_MATCH:
return 0.0f; return 0.0f;
@ -147,8 +147,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
traverseSession->getOffsetDict(), dicNode, bigramCacheMap); traverseSession->getOffsetDict(), dicNode, bigramCacheMap);
return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
} }
case CT_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
return 0.0f; return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap);
case CT_INSERTION: case CT_INSERTION:
return 0.0f; return 0.0f;
case CT_TRANSPOSITION: case CT_TRANSPOSITION:
@ -168,7 +168,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_SUBSTITUTION: case CT_SUBSTITUTION:
// Should return true? // Should return true?
return false; return false;
case CT_NEW_WORD: case CT_NEW_WORD_SPACE_OMITTION:
return false; return false;
case CT_MATCH: case CT_MATCH:
return false; return false;
@ -176,7 +176,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return false; return false;
case CT_TERMINAL: case CT_TERMINAL:
return false; return false;
case CT_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
return false; return false;
case CT_INSERTION: case CT_INSERTION:
return true; return true;
@ -197,7 +197,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return false; return false;
case CT_SUBSTITUTION: case CT_SUBSTITUTION:
return false; return false;
case CT_NEW_WORD: case CT_NEW_WORD_SPACE_OMITTION:
return false; return false;
case CT_MATCH: case CT_MATCH:
return weighting->isProximityDicNode(traverseSession, dicNode); return weighting->isProximityDicNode(traverseSession, dicNode);
@ -205,7 +205,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return false; return false;
case CT_TERMINAL: case CT_TERMINAL:
return false; return false;
case CT_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
return false; return false;
case CT_INSERTION: case CT_INSERTION:
return false; return false;
@ -224,7 +224,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return 0; return 0;
case CT_SUBSTITUTION: case CT_SUBSTITUTION:
return 0; return 0;
case CT_NEW_WORD: case CT_NEW_WORD_SPACE_OMITTION:
return 0; return 0;
case CT_MATCH: case CT_MATCH:
return 1; return 1;
@ -232,7 +232,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return 0; return 0;
case CT_TERMINAL: case CT_TERMINAL:
return 0; return 0;
case CT_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
return 1; return 1;
case CT_INSERTION: case CT_INSERTION:
return 2; return 2;

View file

@ -56,7 +56,8 @@ class Weighting {
const DicTraverseSession *const traverseSession, const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
virtual float getNewWordCost(const DicNode *const dicNode) const = 0; virtual float getNewWordCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual float getNewWordBigramCost( virtual float getNewWordBigramCost(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode, const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
@ -76,7 +77,8 @@ class Weighting {
virtual float getSubstitutionCost() const = 0; virtual float getSubstitutionCost() const = 0;
virtual float getSpaceSubstitutionCost() const = 0; virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
Weighting() {} Weighting() {}
virtual ~Weighting() {} virtual ~Weighting() {}

View file

@ -16,6 +16,7 @@
#include "suggest/core/session/dic_traverse_session.h" #include "suggest/core/session/dic_traverse_session.h"
#include "binary_format.h"
#include "defines.h" #include "defines.h"
#include "dictionary.h" #include "dictionary.h"
#include "dic_traverse_wrapper.h" #include "dic_traverse_wrapper.h"
@ -63,6 +64,7 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer;
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
int prevWordLength) { int prevWordLength) {
mDictionary = dictionary; mDictionary = dictionary;
mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict());
if (!prevWord) { if (!prevWord) {
mPrevWordPos = NOT_VALID_WORD; mPrevWordPos = NOT_VALID_WORD;
return; return;

View file

@ -36,7 +36,8 @@ class DicTraverseSession {
AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr)
: mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0),
mDictionary(0), mDicNodesCache(), mBigramCacheMap(), mDictionary(0), mDicNodesCache(), mBigramCacheMap(),
mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1) { mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1),
mMultiWordCostMultiplier(1.0f) {
// NOTE: mProximityInfoStates is an array of instances. // NOTE: mProximityInfoStates is an array of instances.
// No need to initialize it explicitly here. // No need to initialize it explicitly here.
} }
@ -52,6 +53,7 @@ class DicTraverseSession {
const int maxPointerCount); const int maxPointerCount);
void resetCache(const int nextActiveCacheSize, const int maxWords); void resetCache(const int nextActiveCacheSize, const int maxWords);
// TODO: Remove
const uint8_t *getOffsetDict() const; const uint8_t *getOffsetDict() const;
int getDictFlags() const; int getDictFlags() const;
@ -150,6 +152,10 @@ class DicTraverseSession {
return mProximityInfoStates[0].touchPositionCorrectionEnabled(); return mProximityInfoStates[0].touchPositionCorrectionEnabled();
} }
float getMultiWordCostMultiplier() const {
return mMultiWordCostMultiplier;
}
private: private:
DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession);
// threshold to start caching // threshold to start caching
@ -170,6 +176,11 @@ class DicTraverseSession {
int mInputSize; int mInputSize;
bool mPartiallyCommited; bool mPartiallyCommited;
int mMaxPointerCount; int mMaxPointerCount;
/////////////////////////////////
// Configuration per dictionary
float mMultiWordCostMultiplier;
}; };
} // namespace latinime } // namespace latinime
#endif // LATINIME_DIC_TRAVERSE_SESSION_H #endif // LATINIME_DIC_TRAVERSE_SESSION_H

View file

@ -33,16 +33,9 @@
namespace latinime { namespace latinime {
// Initialization of class constants. // Initialization of class constants.
const int Suggest::LOOKAHEAD_DIC_NODES_CACHE_SIZE = 25;
const int Suggest::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; const int Suggest::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2;
const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f; const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f;
const float Suggest::AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD = 0.6f;
const bool Suggest::CORRECT_SPACE_OMISSION = true;
const bool Suggest::CORRECT_TRANSPOSITION = true;
const bool Suggest::CORRECT_INSERTION = true;
const bool Suggest::CORRECT_OMISSION_G = true;
/** /**
* Returns a set of suggestions for the given input touch points. The commitPoint argument indicates * Returns a set of suggestions for the given input touch points. The commitPoint argument indicates
@ -270,12 +263,8 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
// latest touch point yet. These are needed to apply look-ahead correction operations // latest touch point yet. These are needed to apply look-ahead correction operations
// that require special handling of the latest touch point. For example, with insertions // that require special handling of the latest touch point. For example, with insertions
// (e.g., "thiis" -> "this") the latest touch point should not be consumed at all. // (e.g., "thiis" -> "this") the latest touch point should not be consumed at all.
if (CORRECT_TRANSPOSITION) {
processDicNodeAsTransposition(traverseSession, &dicNode); processDicNodeAsTransposition(traverseSession, &dicNode);
}
if (CORRECT_INSERTION) {
processDicNodeAsInsertion(traverseSession, &dicNode); processDicNodeAsInsertion(traverseSession, &dicNode);
}
} else { // !isLookAheadCorrection } else { // !isLookAheadCorrection
// Only consider typing error corrections if the normalized compound distance is // Only consider typing error corrections if the normalized compound distance is
// below a spatial distance threshold. // below a spatial distance threshold.
@ -531,13 +520,10 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode
DicNode newDicNode; DicNode newDicNode;
DicNodeUtils::initAsRootWithPreviousWord(traverseSession->getDicRootPos(), DicNodeUtils::initAsRootWithPreviousWord(traverseSession->getDicRootPos(),
traverseSession->getOffsetDict(), dicNode, &newDicNode); traverseSession->getOffsetDict(), dicNode, &newDicNode);
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_NEW_WORD, traverseSession, dicNode, const CorrectionType correctionType = spaceSubstitution ?
CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION;
Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode,
&newDicNode, traverseSession->getBigramCacheMap()); &newDicNode, traverseSession->getBigramCacheMap());
if (spaceSubstitution) {
// Merge this with CT_NEW_WORD
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SPACE_SUBSTITUTION,
traverseSession, 0, &newDicNode, 0 /* bigramCacheMap */);
}
traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode);
} }
} // namespace latinime } // namespace latinime

View file

@ -76,31 +76,16 @@ class Suggest : public SuggestInterface {
void processDicNodeAsMatch(DicTraverseSession *traverseSession, void processDicNodeAsMatch(DicTraverseSession *traverseSession,
DicNode *childDicNode) const; DicNode *childDicNode) const;
// Dic nodes cache size for lookahead (autocompletion)
static const int LOOKAHEAD_DIC_NODES_CACHE_SIZE;
// Max characters to lookahead
static const int MAX_LOOKAHEAD;
// Inputs longer than this will autocorrect if the suggestion is multi-word // Inputs longer than this will autocorrect if the suggestion is multi-word
static const int MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT; static const int MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT;
static const int MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE; static const int MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE;
// Base value for converting costs into scores (low so will not autocorrect without classifier)
static const float BASE_OUTPUT_SCORE;
// Threshold for autocorrection classifier // Threshold for autocorrection classifier
static const float AUTOCORRECT_CLASSIFICATION_THRESHOLD; static const float AUTOCORRECT_CLASSIFICATION_THRESHOLD;
// Threshold for computing the language model feature for autocorrect classification
static const float AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD;
// Typing error correction settings
static const bool CORRECT_SPACE_OMISSION;
static const bool CORRECT_TRANSPOSITION;
static const bool CORRECT_INSERTION;
const Traversal *const TRAVERSAL; const Traversal *const TRAVERSAL;
const Scoring *const SCORING; const Scoring *const SCORING;
const Weighting *const WEIGHTING; const Weighting *const WEIGHTING;
static const bool CORRECT_OMISSION_G;
}; };
} // namespace latinime } // namespace latinime
#endif // LATINIME_SUGGEST_IMPL_H #endif // LATINIME_SUGGEST_IMPL_H

View file

@ -18,7 +18,7 @@
namespace latinime { namespace latinime {
const bool TypingTraversal::CORRECT_OMISSION = true; const bool TypingTraversal::CORRECT_OMISSION = true;
const bool TypingTraversal::CORRECT_SPACE_SUBSTITUTION = true; const bool TypingTraversal::CORRECT_NEW_WORD_SPACE_SUBSTITUTION = true;
const bool TypingTraversal::CORRECT_SPACE_OMISSION = true; const bool TypingTraversal::CORRECT_NEW_WORD_SPACE_OMISSION = true;
const TypingTraversal TypingTraversal::sInstance; const TypingTraversal TypingTraversal::sInstance;
} // namespace latinime } // namespace latinime

View file

@ -66,7 +66,7 @@ class TypingTraversal : public Traversal {
AK_FORCE_INLINE bool isSpaceSubstitutionTerminal( AK_FORCE_INLINE bool isSpaceSubstitutionTerminal(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
if (!CORRECT_SPACE_SUBSTITUTION) { if (!CORRECT_NEW_WORD_SPACE_SUBSTITUTION) {
return false; return false;
} }
if (!canDoLookAheadCorrection(traverseSession, dicNode)) { if (!canDoLookAheadCorrection(traverseSession, dicNode)) {
@ -80,7 +80,7 @@ class TypingTraversal : public Traversal {
AK_FORCE_INLINE bool isSpaceOmissionTerminal( AK_FORCE_INLINE bool isSpaceOmissionTerminal(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
if (!CORRECT_SPACE_OMISSION) { if (!CORRECT_NEW_WORD_SPACE_OMISSION) {
return false; return false;
} }
const int inputSize = traverseSession->getInputSize(); const int inputSize = traverseSession->getInputSize();
@ -173,8 +173,8 @@ class TypingTraversal : public Traversal {
private: private:
DISALLOW_COPY_AND_ASSIGN(TypingTraversal); DISALLOW_COPY_AND_ASSIGN(TypingTraversal);
static const bool CORRECT_OMISSION; static const bool CORRECT_OMISSION;
static const bool CORRECT_SPACE_SUBSTITUTION; static const bool CORRECT_NEW_WORD_SPACE_SUBSTITUTION;
static const bool CORRECT_SPACE_OMISSION; static const bool CORRECT_NEW_WORD_SPACE_OMISSION;
static const TypingTraversal sInstance; static const TypingTraversal sInstance;
TypingTraversal() {} TypingTraversal() {}

View file

@ -128,10 +128,12 @@ class TypingWeighting : public Weighting {
return cost + weightedDistance; return cost + weightedDistance;
} }
float getNewWordCost(const DicNode *const dicNode) const { float getNewWordCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
const bool isCapitalized = dicNode->isCapitalized(); const bool isCapitalized = dicNode->isCapitalized();
return isCapitalized ? const float cost = isCapitalized ?
ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD; ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD;
return cost * traverseSession->getMultiWordCostMultiplier();
} }
float getNewWordBigramCost( float getNewWordBigramCost(
@ -183,8 +185,13 @@ class TypingWeighting : public Weighting {
return ScoringParams::SUBSTITUTION_COST; return ScoringParams::SUBSTITUTION_COST;
} }
AK_FORCE_INLINE float getSpaceSubstitutionCost() const { AK_FORCE_INLINE float getSpaceSubstitutionCost(
return ScoringParams::SPACE_SUBSTITUTION_COST; const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
const bool isCapitalized = dicNode->isCapitalized();
const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + (isCapitalized ?
ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD);
return cost * traverseSession->getMultiWordCostMultiplier();
} }
private: private: