444 lines
16 KiB
C++
444 lines
16 KiB
C++
|
/*
|
||
|
* Copyright (C) 2011 The Android Open Source Project
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#include <assert.h>
|
||
|
#include <stdio.h>
|
||
|
#include <string.h>
|
||
|
|
||
|
#define LOG_TAG "LatinIME: correction.cpp"
|
||
|
|
||
|
#include "correction.h"
|
||
|
#include "proximity_info.h"
|
||
|
|
||
|
namespace latinime {
|
||
|
|
||
|
//////////////////////
|
||
|
// inline functions //
|
||
|
//////////////////////
|
||
|
static const char QUOTE = '\'';
|
||
|
|
||
|
inline bool Correction::isQuote(const unsigned short c) {
|
||
|
const unsigned short userTypedChar = mProximityInfo->getPrimaryCharAt(mInputIndex);
|
||
|
return (c == QUOTE && userTypedChar != QUOTE);
|
||
|
}
|
||
|
|
||
|
////////////////
|
||
|
// Correction //
|
||
|
////////////////
|
||
|
|
||
|
// Records the two ranking multipliers. No other member is initialized here.
Correction::Correction(const int typedLetterMultiplier, const int fullWordMultiplier)
        : TYPED_LETTER_MULTIPLIER(typedLetterMultiplier),
          FULL_WORD_MULTIPLIER(fullWordMultiplier) {}
|
||
|
|
||
|
void Correction::initCorrection(const ProximityInfo *pi, const int inputLength,
|
||
|
const int maxDepth) {
|
||
|
mProximityInfo = pi;
|
||
|
mInputLength = inputLength;
|
||
|
mMaxDepth = maxDepth;
|
||
|
mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
|
||
|
mSkippedOutputIndex = -1;
|
||
|
}
|
||
|
|
||
|
void Correction::setCorrectionParams(const int skipPos, const int excessivePos,
|
||
|
const int transposedPos, const int spaceProximityPos, const int missingSpacePos) {
|
||
|
mSkipPos = skipPos;
|
||
|
mExcessivePos = excessivePos;
|
||
|
mTransposedPos = transposedPos;
|
||
|
mSpaceProximityPos = spaceProximityPos;
|
||
|
mMissingSpacePos = missingSpacePos;
|
||
|
}
|
||
|
|
||
|
void Correction::checkState() {
|
||
|
if (DEBUG_DICT) {
|
||
|
int inputCount = 0;
|
||
|
if (mSkipPos >= 0) ++inputCount;
|
||
|
if (mExcessivePos >= 0) ++inputCount;
|
||
|
if (mTransposedPos >= 0) ++inputCount;
|
||
|
// TODO: remove this assert
|
||
|
assert(inputCount <= 1);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
int Correction::getFreqForSplitTwoWords(const int firstFreq, const int secondFreq) {
|
||
|
return Correction::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this);
|
||
|
}
|
||
|
|
||
|
int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLength) {
|
||
|
const int outputIndex = mTerminalOutputIndex;
|
||
|
const int inputIndex = mTerminalInputIndex;
|
||
|
*wordLength = outputIndex + 1;
|
||
|
if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) {
|
||
|
return -1;
|
||
|
}
|
||
|
*word = mWord;
|
||
|
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
|
||
|
: (mInputLength == inputIndex + 1);
|
||
|
return Correction::RankingAlgorithm::calculateFinalFreq(
|
||
|
inputIndex, outputIndex, mMatchedCharCount, freq, sameLength, this);
|
||
|
}
|
||
|
|
||
|
void Correction::initProcessState(const int matchCount, const int inputIndex,
|
||
|
const int outputIndex, const bool traverseAllNodes, const int diffs) {
|
||
|
mMatchedCharCount = matchCount;
|
||
|
mInputIndex = inputIndex;
|
||
|
mOutputIndex = outputIndex;
|
||
|
mTraverseAllNodes = traverseAllNodes;
|
||
|
mDiffs = diffs;
|
||
|
}
|
||
|
|
||
|
void Correction::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex,
|
||
|
bool *traverseAllNodes, int *diffs) {
|
||
|
*matchedCount = mMatchedCharCount;
|
||
|
*inputIndex = mInputIndex;
|
||
|
*outputIndex = mOutputIndex;
|
||
|
*traverseAllNodes = mTraverseAllNodes;
|
||
|
*diffs = mDiffs;
|
||
|
}
|
||
|
|
||
|
// Counts one more matched character.
void Correction::charMatched() {
    mMatchedCharCount += 1;
}
|
||
|
|
||
|
// TODO: remove
|
||
|
int Correction::getOutputIndex() {
|
||
|
return mOutputIndex;
|
||
|
}
|
||
|
|
||
|
// TODO: remove
|
||
|
int Correction::getInputIndex() {
|
||
|
return mInputIndex;
|
||
|
}
|
||
|
|
||
|
// TODO: remove
|
||
|
bool Correction::needsToTraverseAll() {
|
||
|
return mTraverseAllNodes;
|
||
|
}
|
||
|
|
||
|
// Advances the input index by one.
void Correction::incrementInputIndex() {
    mInputIndex += 1;
}
|
||
|
|
||
|
// Advances the output index by one.
void Correction::incrementOutputIndex() {
    mOutputIndex += 1;
}
|
||
|
|
||
|
void Correction::startTraverseAll() {
|
||
|
mTraverseAllNodes = true;
|
||
|
}
|
||
|
|
||
|
bool Correction::needsToPrune() const {
|
||
|
return (mOutputIndex - 1 >= (mTransposedPos >= 0 ? mInputLength - 1 : mMaxDepth)
|
||
|
|| mDiffs > mMaxEditDistance);
|
||
|
}
|
||
|
|
||
|
// Appends |c| to the output word without consuming any input character.
//
// When the traverse-all flag is set and |isTerminal| is true, the current
// input/output indices are recorded as the terminal position and
// TRAVERSE_ALL_ON_TERMINAL is returned; otherwise TRAVERSE_ALL_NOT_ON_TERMINAL.
// The output index is advanced in both cases, after the terminal position has
// been recorded.
Correction::CorrectionType Correction::processSkipChar(
        const int32_t c, const bool isTerminal) {
    mWord[mOutputIndex] = c;
    const bool onTerminal = needsToTraverseAll() && isTerminal;
    if (onTerminal) {
        mTerminalInputIndex = mInputIndex;
        mTerminalOutputIndex = mOutputIndex;
    }
    incrementOutputIndex();
    return onTerminal ? TRAVERSE_ALL_ON_TERMINAL : TRAVERSE_ALL_NOT_ON_TERMINAL;
}
|
||
|
|
||
|
// Processes one candidate character |c| at the current output position and
// updates the traversal state.
//
// Returns:
//  - the result of processSkipChar() when all nodes are being traversed, when
//    |c| is a quote the user did not type, or when |c| is unrelated to the
//    input but sits at the skip position;
//  - UNRELATED when |c| is unrelated to the input and is not at the skip
//    position (no state is advanced);
//  - TRAVERSE_ALL_NOT_ON_TERMINAL when |c| matches but sits at the skip
//    position (only the output index advances);
//  - ON_TERMINAL when |c| matches, the whole input has been consumed and
//    |isTerminal| is true;
//  - NOT_ON_TERMINAL otherwise.
Correction::CorrectionType Correction::processCharAndCalcState(
        const int32_t c, const bool isTerminal) {
    // This has to be done for each virtual char: when the current depth is the
    // excessive position, step over that typed char (provided it is not the
    // last one) before anything is compared.
    if (mExcessivePos == mOutputIndex && mInputIndex < mInputLength - 1) {
        incrementInputIndex();
    }

    const bool skip = (mSkipPos >= 0) && (mSkipPos == mOutputIndex);

    // Not matching against the input (any more): just append the char.
    if (mTraverseAllNodes || isQuote(c)) {
        return processSkipChar(c, isTerminal);
    }

    // While a transposition is tried, the two affected input positions are
    // looked up swapped.
    int proximityInputIndex = mInputIndex;
    if (mTransposedPos >= 0) {
        if (mInputIndex == mTransposedPos) {
            proximityInputIndex = mInputIndex + 1;
        } else if (mInputIndex == mTransposedPos + 1) {
            proximityInputIndex = mInputIndex - 1;
        }
    }

    // Proximity chars are only checked when no other correction is active.
    const bool checkProximityChars =
            mSkipPos < 0 && mExcessivePos < 0 && mTransposedPos < 0;
    const int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
            proximityInputIndex, c, checkProximityChars);

    if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
        if (!skip) {
            return UNRELATED;
        }
        // Skip this letter and continue deeper
        mSkippedOutputIndex = mOutputIndex;
        return processSkipChar(c, isTerminal);
    }

    // No need to skip. Finish traversing and increment skipPos.
    // TODO: Remove this?
    if (skip) {
        mWord[mOutputIndex] = c;
        incrementOutputIndex();
        return TRAVERSE_ALL_NOT_ON_TERMINAL;
    }

    mWord[mOutputIndex] = c;
    if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
        charMatched();
    }
    if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) {
        incrementDiffs();
    }

    // The whole input is consumed either at its last char, or one char earlier
    // when the last typed char is the excessive one.
    const bool atLastInputChar = (mInputLength == mInputIndex + 1);
    const bool beforeTrailingExcessiveChar =
            (mExcessivePos == mInputLength - 1) && (mInputIndex == mInputLength - 2);
    const bool isSameAsUserTypedLength = atLastInputChar || beforeTrailingExcessiveChar;

    CorrectionType currentStateType = NOT_ON_TERMINAL;
    if (isSameAsUserTypedLength) {
        if (isTerminal) {
            mTerminalInputIndex = mInputIndex;
            mTerminalOutputIndex = mOutputIndex;
            currentStateType = ON_TERMINAL;
        }
        // Start traversing all nodes after the index exceeds the user typed length
        startTraverseAll();
    }

    // Finally, we are ready to go to the next character, the next "virtual node".
    // The input index only advances on this path because we are still matching
    // characters to input; the traverse-all path searches for completions and
    // therefore does not consume input.
    incrementInputIndex();

    // Also, the next char is one "virtual node" depth more than this char.
    incrementOutputIndex();

    return currentStateType;
}
|
||
|
|
||
|
// Nothing is released here: the destructor body is empty.
Correction::~Correction() {}
|
||
|
|
||
|
/////////////////////////
|
||
|
// static inline utils //
|
||
|
/////////////////////////
|
||
|
|
||
|
static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;

// Multiplies |num| by 255, saturating at S_INT_MAX once |num| reaches
// S_INT_MAX / 255.
static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
    if (num >= TWO_31ST_DIV_255) {
        return S_INT_MAX;
    }
    return 255 * num;
}
|
||
|
|
||
|
static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;

// Multiplies *base by |multiplier| in place, saturating at S_INT_MAX.
//
// - A value that already equals S_INT_MAX is left untouched.
// - The product is computed in 64 bits so that overflow is detected before it
//   can happen. Multiplying two ints and inspecting the (already overflowed)
//   result is undefined behaviour and misses wrap-arounds that land above the
//   original value.
// - As before, a product smaller than the original value is also reported as
//   S_INT_MAX.
//
// NOTE(review): assumes S_INT_MAX is the largest int value — confirm against
// its definition.
inline static void multiplyIntCapped(const int multiplier, int *base) {
    const int temp = *base;
    if (temp == S_INT_MAX) {
        return;
    }
    // Branch if multiplier == 2 for the optimization
    if (multiplier == 2) {
        *base = TWO_31ST_DIV_2 >= temp ? temp * 2 : S_INT_MAX;
        return;
    }
    const long long product = static_cast<long long>(temp) * multiplier;
    const bool fits = product >= temp && product <= S_INT_MAX;
    *base = fits ? static_cast<int>(product) : S_INT_MAX;
}
|
||
|
|
||
|
inline static int powerIntCapped(const int base, const int n) {
|
||
|
if (n == 0) return 1;
|
||
|
if (base == 2) {
|
||
|
return n < 31 ? 1 << n : S_INT_MAX;
|
||
|
} else {
|
||
|
int ret = base;
|
||
|
for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
|
||
|
return ret;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
inline static void multiplyRate(const int rate, int *freq) {
|
||
|
if (*freq != S_INT_MAX) {
|
||
|
if (*freq > 1000000) {
|
||
|
*freq /= 100;
|
||
|
multiplyIntCapped(rate, freq);
|
||
|
} else {
|
||
|
multiplyIntCapped(rate, freq);
|
||
|
*freq /= 100;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//////////////////////
|
||
|
// RankingAlgorithm //
|
||
|
//////////////////////
|
||
|
|
||
|
int Correction::RankingAlgorithm::calculateFinalFreq(
|
||
|
const int inputIndex, const int outputIndex,
|
||
|
const int matchCount, const int freq, const bool sameLength,
|
||
|
const Correction* correction) {
|
||
|
const int skipPos = correction->getSkipPos();
|
||
|
const int excessivePos = correction->getExcessivePos();
|
||
|
const int transposedPos = correction->getTransposedPos();
|
||
|
const int inputLength = correction->mInputLength;
|
||
|
const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER;
|
||
|
const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER;
|
||
|
const ProximityInfo *proximityInfo = correction->mProximityInfo;
|
||
|
const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
|
||
|
|
||
|
// TODO: Demote by edit distance
|
||
|
int finalFreq = freq * matchWeight;
|
||
|
if (skipPos >= 0) {
|
||
|
if (inputLength >= 2) {
|
||
|
const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE
|
||
|
* (10 * inputLength - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X)
|
||
|
/ (10 * inputLength
|
||
|
- WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X + 10);
|
||
|
if (DEBUG_DICT_FULL) {
|
||
|
LOGI("Demotion rate for missing character is %d.", demotionRate);
|
||
|
}
|
||
|
multiplyRate(demotionRate, &finalFreq);
|
||
|
} else {
|
||
|
finalFreq = 0;
|
||
|
}
|
||
|
}
|
||
|
if (transposedPos >= 0) multiplyRate(
|
||
|
WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE, &finalFreq);
|
||
|
if (excessivePos >= 0) {
|
||
|
multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_DEMOTION_RATE, &finalFreq);
|
||
|
if (!proximityInfo->existsAdjacentProximityChars(inputIndex)) {
|
||
|
// If an excessive character is not adjacent to the left char or the right char,
|
||
|
// we will demote this word.
|
||
|
multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_OUT_OF_PROXIMITY_DEMOTION_RATE, &finalFreq);
|
||
|
}
|
||
|
}
|
||
|
int lengthFreq = typedLetterMultiplier;
|
||
|
multiplyIntCapped(powerIntCapped(typedLetterMultiplier, outputIndex), &lengthFreq);
|
||
|
if ((outputIndex + 1) == matchCount) {
|
||
|
// Full exact match
|
||
|
if (outputIndex > 1) {
|
||
|
if (DEBUG_DICT) {
|
||
|
LOGI("Found full matched word.");
|
||
|
}
|
||
|
multiplyRate(FULL_MATCHED_WORDS_PROMOTION_RATE, &finalFreq);
|
||
|
}
|
||
|
if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0) {
|
||
|
finalFreq = capped255MultForFullMatchAccentsOrCapitalizationDifference(finalFreq);
|
||
|
}
|
||
|
} else if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0
|
||
|
&& outputIndex > 0) {
|
||
|
// A word with proximity corrections
|
||
|
if (DEBUG_DICT) {
|
||
|
LOGI("Found one proximity correction.");
|
||
|
}
|
||
|
multiplyIntCapped(typedLetterMultiplier, &finalFreq);
|
||
|
multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq);
|
||
|
}
|
||
|
if (DEBUG_DICT) {
|
||
|
LOGI("calc: %d, %d", outputIndex, sameLength);
|
||
|
}
|
||
|
if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq);
|
||
|
return finalFreq;
|
||
|
}
|
||
|
|
||
|
int Correction::RankingAlgorithm::calcFreqForSplitTwoWords(
|
||
|
const int firstFreq, const int secondFreq, const Correction* correction) {
|
||
|
const int spaceProximityPos = correction->mSpaceProximityPos;
|
||
|
const int missingSpacePos = correction->mMissingSpacePos;
|
||
|
if (DEBUG_DICT) {
|
||
|
int inputCount = 0;
|
||
|
if (spaceProximityPos >= 0) ++inputCount;
|
||
|
if (missingSpacePos >= 0) ++inputCount;
|
||
|
assert(inputCount <= 1);
|
||
|
}
|
||
|
const bool isSpaceProximity = spaceProximityPos >= 0;
|
||
|
const int inputLength = correction->mInputLength;
|
||
|
const int firstWordLength = isSpaceProximity ? spaceProximityPos : missingSpacePos;
|
||
|
const int secondWordLength = isSpaceProximity
|
||
|
? (inputLength - spaceProximityPos - 1)
|
||
|
: (inputLength - missingSpacePos);
|
||
|
const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER;
|
||
|
|
||
|
if (firstWordLength == 0 || secondWordLength == 0) {
|
||
|
return 0;
|
||
|
}
|
||
|
const int firstDemotionRate = 100 - 100 / (firstWordLength + 1);
|
||
|
int tempFirstFreq = firstFreq;
|
||
|
multiplyRate(firstDemotionRate, &tempFirstFreq);
|
||
|
|
||
|
const int secondDemotionRate = 100 - 100 / (secondWordLength + 1);
|
||
|
int tempSecondFreq = secondFreq;
|
||
|
multiplyRate(secondDemotionRate, &tempSecondFreq);
|
||
|
|
||
|
const int totalLength = firstWordLength + secondWordLength;
|
||
|
|
||
|
// Promote pairFreq with multiplying by 2, because the word length is the same as the typed
|
||
|
// length.
|
||
|
int totalFreq = tempFirstFreq + tempSecondFreq;
|
||
|
|
||
|
// This is a workaround to try offsetting the not-enough-demotion which will be done in
|
||
|
// calcNormalizedScore in Utils.java.
|
||
|
// In calcNormalizedScore the score will be demoted by (1 - 1 / length)
|
||
|
// but we demoted only (1 - 1 / (length + 1)) so we will additionally adjust freq by
|
||
|
// (1 - 1 / length) / (1 - 1 / (length + 1)) = (1 - 1 / (length * length))
|
||
|
const int normalizedScoreNotEnoughDemotionAdjustment = 100 - 100 / (totalLength * totalLength);
|
||
|
multiplyRate(normalizedScoreNotEnoughDemotionAdjustment, &totalFreq);
|
||
|
|
||
|
// At this moment, totalFreq is calculated by the following formula:
|
||
|
// (firstFreq * (1 - 1 / (firstWordLength + 1)) + secondFreq * (1 - 1 / (secondWordLength + 1)))
|
||
|
// * (1 - 1 / totalLength) / (1 - 1 / (totalLength + 1))
|
||
|
|
||
|
multiplyIntCapped(powerIntCapped(typedLetterMultiplier, totalLength), &totalFreq);
|
||
|
|
||
|
// This is another workaround to offset the demotion which will be done in
|
||
|
// calcNormalizedScore in Utils.java.
|
||
|
// In calcNormalizedScore the score will be demoted by (1 - 1 / length) so we have to promote
|
||
|
// the same amount because we already have adjusted the synthetic freq of this "missing or
|
||
|
// mistyped space" suggestion candidate above in this method.
|
||
|
const int normalizedScoreDemotionRateOffset = (100 + 100 / totalLength);
|
||
|
multiplyRate(normalizedScoreDemotionRateOffset, &totalFreq);
|
||
|
|
||
|
if (isSpaceProximity) {
|
||
|
// A word pair with one space proximity correction
|
||
|
if (DEBUG_DICT) {
|
||
|
LOGI("Found a word pair with space proximity correction.");
|
||
|
}
|
||
|
multiplyIntCapped(typedLetterMultiplier, &totalFreq);
|
||
|
multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &totalFreq);
|
||
|
}
|
||
|
|
||
|
multiplyRate(WORDS_WITH_MISSING_SPACE_CHARACTER_DEMOTION_RATE, &totalFreq);
|
||
|
return totalFreq;
|
||
|
}
|
||
|
|
||
|
} // namespace latinime
|