Merge "Support terminal insertion error correction"

main
Satoshi Kataoka 2013-07-26 02:25:45 +00:00 committed by Android (Google) Code Review
commit ea2ab41c4f
9 changed files with 43 additions and 10 deletions

View File

@ -381,6 +381,7 @@ typedef enum {
CT_TRANSPOSITION, CT_TRANSPOSITION,
CT_COMPLETION, CT_COMPLETION,
CT_TERMINAL, CT_TERMINAL,
CT_TERMINAL_INSERTION,
// Create new word with space omission // Create new word with space omission
CT_NEW_WORD_SPACE_OMITTION, CT_NEW_WORD_SPACE_OMITTION,
// Create new word with space substitution // Create new word with space substitution

View File

@ -31,6 +31,7 @@
#define PROF_TRANSPOSITION(profiler) profiler.profTransposition() #define PROF_TRANSPOSITION(profiler) profiler.profTransposition()
#define PROF_NEARESTKEY(profiler) profiler.profNearestKey() #define PROF_NEARESTKEY(profiler) profiler.profNearestKey()
#define PROF_TERMINAL(profiler) profiler.profTerminal() #define PROF_TERMINAL(profiler) profiler.profTerminal()
#define PROF_TERMINAL_INSERTION(profiler) profiler.profTerminalInsertion()
#define PROF_NEW_WORD(profiler) profiler.profNewWord() #define PROF_NEW_WORD(profiler) profiler.profNewWord()
#define PROF_NEW_WORD_BIGRAM(profiler) profiler.profNewWordBigram() #define PROF_NEW_WORD_BIGRAM(profiler) profiler.profNewWordBigram()
#define PROF_NODE_RESET(profiler) profiler.reset() #define PROF_NODE_RESET(profiler) profiler.reset()
@ -47,6 +48,7 @@
#define PROF_TRANSPOSITION(profiler) #define PROF_TRANSPOSITION(profiler)
#define PROF_NEARESTKEY(profiler) #define PROF_NEARESTKEY(profiler)
#define PROF_TERMINAL(profiler) #define PROF_TERMINAL(profiler)
#define PROF_TERMINAL_INSERTION(profiler)
#define PROF_NEW_WORD(profiler) #define PROF_NEW_WORD(profiler)
#define PROF_NEW_WORD_BIGRAM(profiler) #define PROF_NEW_WORD_BIGRAM(profiler)
#define PROF_NODE_RESET(profiler) #define PROF_NODE_RESET(profiler)
@ -62,7 +64,7 @@ class DicNodeProfiler {
: mProfOmission(0), mProfInsertion(0), mProfTransposition(0), : mProfOmission(0), mProfInsertion(0), mProfTransposition(0),
mProfAdditionalProximity(0), mProfSubstitution(0), mProfAdditionalProximity(0), mProfSubstitution(0),
mProfSpaceSubstitution(0), mProfSpaceOmission(0), mProfSpaceSubstitution(0), mProfSpaceOmission(0),
mProfMatch(0), mProfCompletion(0), mProfTerminal(0), mProfMatch(0), mProfCompletion(0), mProfTerminal(0), mProfTerminalInsertion(0),
mProfNearestKey(0), mProfNewWord(0), mProfNewWordBigram(0) {} mProfNearestKey(0), mProfNewWord(0), mProfNewWordBigram(0) {}
int mProfOmission; int mProfOmission;
@ -75,6 +77,7 @@ class DicNodeProfiler {
int mProfMatch; int mProfMatch;
int mProfCompletion; int mProfCompletion;
int mProfTerminal; int mProfTerminal;
int mProfTerminalInsertion;
int mProfNearestKey; int mProfNearestKey;
int mProfNewWord; int mProfNewWord;
int mProfNewWordBigram; int mProfNewWordBigram;
@ -123,6 +126,10 @@ class DicNodeProfiler {
++mProfTerminal; ++mProfTerminal;
} }
void profTerminalInsertion() {
++mProfTerminalInsertion;
}
void profNewWord() { void profNewWord() {
++mProfNewWord; ++mProfNewWord;
} }

View File

@ -50,6 +50,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_TERMINAL: case CT_TERMINAL:
PROF_TERMINAL(node->mProfiler); PROF_TERMINAL(node->mProfiler);
return; return;
case CT_TERMINAL_INSERTION:
PROF_TERMINAL_INSERTION(node->mProfiler);
return;
case CT_NEW_WORD_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
PROF_SPACE_SUBSTITUTION(node->mProfiler); PROF_SPACE_SUBSTITUTION(node->mProfiler);
return; return;
@ -113,6 +116,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return weighting->getCompletionCost(traverseSession, dicNode); return weighting->getCompletionCost(traverseSession, dicNode);
case CT_TERMINAL: case CT_TERMINAL:
return weighting->getTerminalSpatialCost(traverseSession, dicNode); return weighting->getTerminalSpatialCost(traverseSession, dicNode);
case CT_TERMINAL_INSERTION:
return weighting->getTerminalInsertionCost(traverseSession, dicNode);
case CT_NEW_WORD_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
return weighting->getSpaceSubstitutionCost(traverseSession, dicNode); return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
case CT_INSERTION: case CT_INSERTION:
@ -146,6 +151,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
traverseSession->getBinaryDictionaryInfo(), dicNode, multiBigramMap); traverseSession->getBinaryDictionaryInfo(), dicNode, multiBigramMap);
return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
} }
case CT_TERMINAL_INSERTION:
return 0.0f;
case CT_NEW_WORD_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
return weighting->getNewWordBigramLanguageCost( return weighting->getNewWordBigramLanguageCost(
traverseSession, parentDicNode, multiBigramMap); traverseSession, parentDicNode, multiBigramMap);
@ -163,9 +170,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_OMISSION: case CT_OMISSION:
return 0; return 0;
case CT_ADDITIONAL_PROXIMITY: case CT_ADDITIONAL_PROXIMITY:
return 0; return 0; /* 0 because CT_MATCH will be called */
case CT_SUBSTITUTION: case CT_SUBSTITUTION:
return 0; return 0; /* 0 because CT_MATCH will be called */
case CT_NEW_WORD_SPACE_OMITTION: case CT_NEW_WORD_SPACE_OMITTION:
return 0; return 0;
case CT_MATCH: case CT_MATCH:
@ -174,12 +181,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return 1; return 1;
case CT_TERMINAL: case CT_TERMINAL:
return 0; return 0;
case CT_TERMINAL_INSERTION:
return 1;
case CT_NEW_WORD_SPACE_SUBSTITUTION: case CT_NEW_WORD_SPACE_SUBSTITUTION:
return 1; return 1;
case CT_INSERTION: case CT_INSERTION:
return 2; return 2; /* look ahead + skip the current char */
case CT_TRANSPOSITION: case CT_TRANSPOSITION:
return 2; return 2; /* look ahead + skip the current char */
default: default:
return 0; return 0;
} }

View File

@ -67,6 +67,10 @@ class Weighting {
const DicTraverseSession *const traverseSession, const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0; const DicNode *const dicNode) const = 0;
virtual float getTerminalInsertionCost(
const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual float getTerminalLanguageCost( virtual float getTerminalLanguageCost(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode, const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
float dicNodeLanguageImprobability) const = 0; float dicNodeLanguageImprobability) const = 0;

View File

@ -365,17 +365,17 @@ void Suggest::processTerminalDicNode(
if (!dicNode->isTerminalWordNode()) { if (!dicNode->isTerminalWordNode()) {
return; return;
} }
if (TRAVERSAL->needsToTraverseAllUserInput()
&& dicNode->getInputIndex(0) < traverseSession->getInputSize()) {
return;
}
if (dicNode->shouldBeFilterdBySafetyNetForBigram()) { if (dicNode->shouldBeFilterdBySafetyNetForBigram()) {
return; return;
} }
// Create a non-cached node here. // Create a non-cached node here.
DicNode terminalDicNode; DicNode terminalDicNode;
DicNodeUtils::initByCopy(dicNode, &terminalDicNode); DicNodeUtils::initByCopy(dicNode, &terminalDicNode);
if (TRAVERSAL->needsToTraverseAllUserInput()
&& dicNode->getInputIndex(0) < traverseSession->getInputSize()) {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0,
&terminalDicNode, traverseSession->getMultiBigramMap());
}
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0, Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0,
&terminalDicNode, traverseSession->getMultiBigramMap()); &terminalDicNode, traverseSession->getMultiBigramMap());
traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode); traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode);

View File

@ -34,6 +34,7 @@ const float ScoringParams::OMISSION_COST = 0.458f;
const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f; const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f;
const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f; const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f;
const float ScoringParams::INSERTION_COST = 0.730f; const float ScoringParams::INSERTION_COST = 0.730f;
const float ScoringParams::TERMINAL_INSERTION_COST = 0.93f;
const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f; const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f;
const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.70f; const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.70f;
const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f; const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f;

View File

@ -42,6 +42,7 @@ class ScoringParams {
static const float OMISSION_COST_SAME_CHAR; static const float OMISSION_COST_SAME_CHAR;
static const float OMISSION_COST_FIRST_CHAR; static const float OMISSION_COST_FIRST_CHAR;
static const float INSERTION_COST; static const float INSERTION_COST;
static const float TERMINAL_INSERTION_COST;
static const float INSERTION_COST_SAME_CHAR; static const float INSERTION_COST_SAME_CHAR;
static const float INSERTION_COST_PROXIMITY_CHAR; static const float INSERTION_COST_PROXIMITY_CHAR;
static const float INSERTION_COST_FIRST_CHAR; static const float INSERTION_COST_FIRST_CHAR;

View File

@ -44,6 +44,7 @@ ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType,
break; break;
case CT_SUBSTITUTION: case CT_SUBSTITUTION:
case CT_INSERTION: case CT_INSERTION:
case CT_TERMINAL_INSERTION:
case CT_TRANSPOSITION: case CT_TRANSPOSITION:
return ET_EDIT_CORRECTION; return ET_EDIT_CORRECTION;
case CT_NEW_WORD_SPACE_OMITTION: case CT_NEW_WORD_SPACE_OMITTION:

View File

@ -175,6 +175,15 @@ class TypingWeighting : public Weighting {
return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
} }
float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
const int inputIndex = dicNode->getInputIndex(0);
const int inputSize = traverseSession->getInputSize();
ASSERT(inputIndex < inputSize);
// TODO: Implement more efficient logic
return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex);
}
AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
return false; return false;
} }