parent
be06bce18b
commit
866a6ced57
|
@ -187,7 +187,7 @@ void BigramDictionary::fillBigramAddressToProbabilityMapAndFilter(const int *pre
|
||||||
&pos);
|
&pos);
|
||||||
(*map)[bigramPos] = probability;
|
(*map)[bigramPos] = probability;
|
||||||
setInFilter(filter, bigramPos);
|
setInFilter(filter, bigramPos);
|
||||||
} while (0 != (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags));
|
} while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const {
|
bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const {
|
||||||
|
|
|
@ -271,7 +271,7 @@ namespace latinime {
|
||||||
return probability;
|
return probability;
|
||||||
}
|
}
|
||||||
count++;
|
count++;
|
||||||
} while ((0 != (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags))
|
} while ((BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags)
|
||||||
&& count < MAX_BIGRAMS_CONSIDERED_PER_CONTEXT);
|
&& count < MAX_BIGRAMS_CONSIDERED_PER_CONTEXT);
|
||||||
if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
|
if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
|
||||||
// TODO: does this -1 mean NOT_VALID_WORD?
|
// TODO: does this -1 mean NOT_VALID_WORD?
|
||||||
|
|
|
@ -29,16 +29,14 @@ class Scoring {
|
||||||
public:
|
public:
|
||||||
virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
|
virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
|
||||||
const bool forceCommit) const = 0;
|
const bool forceCommit) const = 0;
|
||||||
virtual bool getMostProbableString(
|
virtual bool getMostProbableString(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession, const int terminalSize,
|
const int terminalSize, const float languageWeight, int *const outputCodePoints,
|
||||||
const float languageWeight, int *const outputCodePoints, int *const type,
|
int *const type, int *const freq) const = 0;
|
||||||
int *const freq) const = 0;
|
|
||||||
virtual void safetyNetForMostProbableString(const int terminalSize,
|
virtual void safetyNetForMostProbableString(const int terminalSize,
|
||||||
const int maxScore, int *const outputCodePoints, int *const frequencies) const = 0;
|
const int maxScore, int *const outputCodePoints, int *const frequencies) const = 0;
|
||||||
// TODO: Make more generic
|
// TODO: Make more generic
|
||||||
virtual void searchWordWithDoubleLetter(DicNode *terminals,
|
virtual void searchWordWithDoubleLetter(DicNode *terminals, const int terminalSize,
|
||||||
const int terminalSize, int *doubleLetterTerminalIndex,
|
int *doubleLetterTerminalIndex, DoubleLetterLevel *doubleLetterLevel) const = 0;
|
||||||
DoubleLetterLevel *doubleLetterLevel) const = 0;
|
|
||||||
virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession,
|
virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession,
|
||||||
DicNode *const terminals, const int size) const = 0;
|
DicNode *const terminals, const int size) const = 0;
|
||||||
virtual float getDoubleLetterDemotionDistanceCost(const int terminalIndex,
|
virtual float getDoubleLetterDemotionDistanceCost(const int terminalIndex,
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include "defines.h"
|
#include "defines.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
class Traversal;
|
class Traversal;
|
||||||
class Scoring;
|
class Scoring;
|
||||||
class Weighting;
|
class Weighting;
|
||||||
|
|
|
@ -39,9 +39,8 @@ class Traversal {
|
||||||
const DicNode *const dicNode) const = 0;
|
const DicNode *const dicNode) const = 0;
|
||||||
virtual bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession,
|
virtual bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession,
|
||||||
const DicNode *const dicNode) const = 0;
|
const DicNode *const dicNode) const = 0;
|
||||||
virtual ProximityType getProximityType(
|
virtual ProximityType getProximityType(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
|
const DicNode *const dicNode, const DicNode *const childDicNode) const = 0;
|
||||||
const DicNode *const childDicNode) const = 0;
|
|
||||||
virtual bool sameAsTyped(const DicTraverseSession *const traverseSession,
|
virtual bool sameAsTyped(const DicTraverseSession *const traverseSession,
|
||||||
const DicNode *const dicNode) const = 0;
|
const DicNode *const dicNode) const = 0;
|
||||||
virtual bool needsToTraverseAllUserInput() const = 0;
|
virtual bool needsToTraverseAllUserInput() const = 0;
|
||||||
|
@ -49,9 +48,8 @@ class Traversal {
|
||||||
virtual bool allowPartialCommit() const = 0;
|
virtual bool allowPartialCommit() const = 0;
|
||||||
virtual int getDefaultExpandDicNodeSize() const = 0;
|
virtual int getDefaultExpandDicNodeSize() const = 0;
|
||||||
virtual int getMaxCacheSize() const = 0;
|
virtual int getMaxCacheSize() const = 0;
|
||||||
virtual bool isPossibleOmissionChildNode(
|
virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
|
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
|
||||||
const DicNode *const dicNode) const = 0;
|
|
||||||
virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0;
|
virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
@ -69,8 +69,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
|
/* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
|
||||||
const CorrectionType correctionType,
|
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession,
|
|
||||||
const DicNode *const parentDicNode, DicNode *const dicNode,
|
const DicNode *const parentDicNode, DicNode *const dicNode,
|
||||||
hash_map_compat<int, int16_t> *const bigramCacheMap) {
|
hash_map_compat<int, int16_t> *const bigramCacheMap) {
|
||||||
const int inputSize = traverseSession->getInputSize();
|
const int inputSize = traverseSession->getInputSize();
|
||||||
|
@ -94,9 +93,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
|
/* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
|
||||||
const CorrectionType correctionType,
|
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
|
const DicNode *const parentDicNode, const DicNode *const dicNode,
|
||||||
const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) {
|
DicNode_InputStateG *const inputStateG) {
|
||||||
switch(correctionType) {
|
switch(correctionType) {
|
||||||
case CT_OMISSION:
|
case CT_OMISSION:
|
||||||
return weighting->getOmissionCost(parentDicNode, dicNode);
|
return weighting->getOmissionCost(parentDicNode, dicNode);
|
||||||
|
|
|
@ -20,11 +20,12 @@
|
||||||
#include "suggest/policyimpl/typing/scoring_params.h"
|
#include "suggest/policyimpl/typing/scoring_params.h"
|
||||||
|
|
||||||
namespace latinime {
|
namespace latinime {
|
||||||
|
|
||||||
const TypingWeighting TypingWeighting::sInstance;
|
const TypingWeighting TypingWeighting::sInstance;
|
||||||
|
|
||||||
ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType,
|
ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType,
|
||||||
const DicTraverseSession *const traverseSession,
|
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
|
||||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
const DicNode *const dicNode) const {
|
||||||
switch (correctionType) {
|
switch (correctionType) {
|
||||||
case CT_MATCH:
|
case CT_MATCH:
|
||||||
if (isProximityDicNode(traverseSession, dicNode)) {
|
if (isProximityDicNode(traverseSession, dicNode)) {
|
||||||
|
|
|
@ -34,8 +34,8 @@ class TypingWeighting : public Weighting {
|
||||||
static const TypingWeighting *getInstance() { return &sInstance; }
|
static const TypingWeighting *getInstance() { return &sInstance; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
float getTerminalSpatialCost(
|
float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
|
const DicNode *const dicNode) const {
|
||||||
float cost = 0.0f;
|
float cost = 0.0f;
|
||||||
if (dicNode->hasMultipleWords()) {
|
if (dicNode->hasMultipleWords()) {
|
||||||
cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
|
cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
|
||||||
|
@ -66,9 +66,8 @@ class TypingWeighting : public Weighting {
|
||||||
return cost;
|
return cost;
|
||||||
}
|
}
|
||||||
|
|
||||||
float getMatchedCost(
|
float getMatchedCost(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
|
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
|
||||||
DicNode_InputStateG *inputStateG) const {
|
|
||||||
const int pointIndex = dicNode->getInputIndex(0);
|
const int pointIndex = dicNode->getInputIndex(0);
|
||||||
// Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on
|
// Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on
|
||||||
// the keyboard (like accented letters)
|
// the keyboard (like accented letters)
|
||||||
|
@ -85,8 +84,8 @@ class TypingWeighting : public Weighting {
|
||||||
return weightedDistance + cost;
|
return weightedDistance + cost;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isProximityDicNode(
|
bool isProximityDicNode(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
|
const DicNode *const dicNode) const {
|
||||||
const int pointIndex = dicNode->getInputIndex(0);
|
const int pointIndex = dicNode->getInputIndex(0);
|
||||||
const int primaryCodePoint = toBaseLowerCase(
|
const int primaryCodePoint = toBaseLowerCase(
|
||||||
traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex));
|
traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex));
|
||||||
|
@ -94,9 +93,8 @@ class TypingWeighting : public Weighting {
|
||||||
return primaryCodePoint != dicNodeChar;
|
return primaryCodePoint != dicNodeChar;
|
||||||
}
|
}
|
||||||
|
|
||||||
float getTranspositionCost(
|
float getTranspositionCost(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
|
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
||||||
const DicNode *const dicNode) const {
|
|
||||||
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
|
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
|
||||||
const int prevCodePoint = parentDicNode->getNodeCodePoint();
|
const int prevCodePoint = parentDicNode->getNodeCodePoint();
|
||||||
const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
||||||
|
@ -110,8 +108,7 @@ class TypingWeighting : public Weighting {
|
||||||
return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
|
return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
|
||||||
}
|
}
|
||||||
|
|
||||||
float getInsertionCost(
|
float getInsertionCost(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession,
|
|
||||||
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
||||||
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
|
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
|
||||||
const int prevCodePoint =
|
const int prevCodePoint =
|
||||||
|
@ -137,8 +134,8 @@ class TypingWeighting : public Weighting {
|
||||||
return cost * traverseSession->getMultiWordCostMultiplier();
|
return cost * traverseSession->getMultiWordCostMultiplier();
|
||||||
}
|
}
|
||||||
|
|
||||||
float getNewWordBigramCost(
|
float getNewWordBigramCost(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
|
const DicNode *const dicNode,
|
||||||
hash_map_compat<int, int16_t> *const bigramCacheMap) const {
|
hash_map_compat<int, int16_t> *const bigramCacheMap) const {
|
||||||
return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(),
|
return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(),
|
||||||
dicNode, bigramCacheMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
dicNode, bigramCacheMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
||||||
|
@ -174,8 +171,7 @@ class TypingWeighting : public Weighting {
|
||||||
return ScoringParams::SUBSTITUTION_COST;
|
return ScoringParams::SUBSTITUTION_COST;
|
||||||
}
|
}
|
||||||
|
|
||||||
AK_FORCE_INLINE float getSpaceSubstitutionCost(
|
AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
|
||||||
const DicTraverseSession *const traverseSession,
|
|
||||||
const DicNode *const dicNode) const {
|
const DicNode *const dicNode) const {
|
||||||
const bool isCapitalized = dicNode->isCapitalized();
|
const bool isCapitalized = dicNode->isCapitalized();
|
||||||
const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + (isCapitalized ?
|
const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + (isCapitalized ?
|
||||||
|
|
Loading…
Reference in New Issue