diff --git a/java/src/com/android/inputmethod/latin/BinaryDictionary.java b/java/src/com/android/inputmethod/latin/BinaryDictionary.java index 2ae54348a..35c4681d9 100644 --- a/java/src/com/android/inputmethod/latin/BinaryDictionary.java +++ b/java/src/com/android/inputmethod/latin/BinaryDictionary.java @@ -188,7 +188,8 @@ public final class BinaryDictionary extends Dictionary { int[][] prevWordCodePointArrays, boolean[] isBeginningOfSentenceArray, int prevWordCount, int[] outputSuggestionCount, int[] outputCodePoints, int[] outputScores, int[] outputIndices, int[] outputTypes, - int[] outputAutoCommitFirstWordConfidence, float[] inOutLanguageWeight); + int[] outputAutoCommitFirstWordConfidence, + float[] inOutWeightOfLangModelVsSpatialModel); private static native boolean addUnigramEntryNative(long dict, int[] word, int probability, int[] shortcutTarget, int shortcutProbability, boolean isBeginningOfSentence, boolean isNotAWord, boolean isBlacklisted, int timestamp); @@ -256,7 +257,8 @@ public final class BinaryDictionary extends Dictionary { public ArrayList getSuggestions(final WordComposer composer, final PrevWordsInfo prevWordsInfo, final ProximityInfo proximityInfo, final SettingsValuesForSuggestion settingsValuesForSuggestion, - final int sessionId, final float[] inOutLanguageWeight) { + final int sessionId, final float weightForLocale, + final float[] inOutWeightOfLangModelVsSpatialModel) { if (!isValidDictionary()) { return null; } @@ -284,10 +286,12 @@ public final class BinaryDictionary extends Dictionary { settingsValuesForSuggestion.mSpaceAwareGestureEnabled); session.mNativeSuggestOptions.setAdditionalFeaturesOptions( settingsValuesForSuggestion.mAdditionalFeaturesSettingValues); - if (inOutLanguageWeight != null) { - session.mInputOutputLanguageWeight[0] = inOutLanguageWeight[0]; + if (inOutWeightOfLangModelVsSpatialModel != null) { + session.mInputOutputWeightOfLangModelVsSpatialModel[0] = + inOutWeightOfLangModelVsSpatialModel[0]; } else { - session.mInputOutputLanguageWeight[0] = Dictionary.NOT_A_LANGUAGE_WEIGHT; + session.mInputOutputWeightOfLangModelVsSpatialModel[0] = + Dictionary.NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL; } // TOOD: Pass multiple previous words information for n-gram. getSuggestionsNative(mNativeDict, proximityInfo.getNativeProximityInfo(), @@ -298,9 +302,11 @@ public final class BinaryDictionary extends Dictionary { session.mIsBeginningOfSentenceArray, prevWordsInfo.getPrevWordCount(), session.mOutputSuggestionCount, session.mOutputCodePoints, session.mOutputScores, session.mSpaceIndices, session.mOutputTypes, - session.mOutputAutoCommitFirstWordConfidence, session.mInputOutputLanguageWeight); - if (inOutLanguageWeight != null) { - inOutLanguageWeight[0] = session.mInputOutputLanguageWeight[0]; + session.mOutputAutoCommitFirstWordConfidence, + session.mInputOutputWeightOfLangModelVsSpatialModel); + if (inOutWeightOfLangModelVsSpatialModel != null) { + inOutWeightOfLangModelVsSpatialModel[0] = + session.mInputOutputWeightOfLangModelVsSpatialModel[0]; } final int count = session.mOutputSuggestionCount[0]; final ArrayList suggestions = new ArrayList<>(); @@ -314,7 +320,8 @@ public final class BinaryDictionary extends Dictionary { if (len > 0) { suggestions.add(new SuggestedWordInfo( new String(session.mOutputCodePoints, start, len), - session.mOutputScores[j], session.mOutputTypes[j], this /* sourceDict */, + (int)(session.mOutputScores[j] * weightForLocale), session.mOutputTypes[j], + this /* sourceDict */, session.mSpaceIndices[j] /* indexOfTouchPointOfSecondWord */, session.mOutputAutoCommitFirstWordConfidence[0])); } diff --git a/java/src/com/android/inputmethod/latin/DicTraverseSession.java b/java/src/com/android/inputmethod/latin/DicTraverseSession.java index b341f623e..2751c1250 100644 --- a/java/src/com/android/inputmethod/latin/DicTraverseSession.java +++ b/java/src/com/android/inputmethod/latin/DicTraverseSession.java @@ -40,7 +40,7 @@ public final class DicTraverseSession { public final int[] mOutputTypes = new int[MAX_RESULTS]; // Only one result is ever used public final int[] mOutputAutoCommitFirstWordConfidence = new int[1]; - public final float[] mInputOutputLanguageWeight = new float[1]; + public final float[] mInputOutputWeightOfLangModelVsSpatialModel = new float[1]; public final NativeSuggestOptions mNativeSuggestOptions = new NativeSuggestOptions(); diff --git a/java/src/com/android/inputmethod/latin/Dictionary.java b/java/src/com/android/inputmethod/latin/Dictionary.java index cad9ee7d8..b58a52b41 100644 --- a/java/src/com/android/inputmethod/latin/Dictionary.java +++ b/java/src/com/android/inputmethod/latin/Dictionary.java @@ -31,7 +31,7 @@ import java.util.HashSet; */ public abstract class Dictionary { public static final int NOT_A_PROBABILITY = -1; - public static final float NOT_A_LANGUAGE_WEIGHT = -1.0f; + public static final float NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL = -1.0f; // The following types do not actually come from real dictionary instances, so we create // corresponding instances. @@ -88,15 +88,18 @@ public abstract class Dictionary { * @param proximityInfo the object for key proximity. May be ignored by some implementations. * @param settingsValuesForSuggestion the settings values used for the suggestion. * @param sessionId the session id. - * @param inOutLanguageWeight the language weight used for generating suggestions. - * inOutLanguageWeight is a float array that has only one element. This can be updated when the - * different language weight is used. + * @param weightForLocale the weight given to this locale, to multiply the output scores for + * multilingual input. + * @param inOutWeightOfLangModelVsSpatialModel the weight of the language model as a ratio of + * the spatial model, used for generating suggestions. inOutWeightOfLangModelVsSpatialModel is + * a float array that has only one element. This can be updated when a different value is used. * @return the list of suggestions (possibly null if none) */ abstract public ArrayList getSuggestions(final WordComposer composer, final PrevWordsInfo prevWordsInfo, final ProximityInfo proximityInfo, final SettingsValuesForSuggestion settingsValuesForSuggestion, - final int sessionId, final float[] inOutLanguageWeight); + final int sessionId, final float weightForLocale, + final float[] inOutWeightOfLangModelVsSpatialModel); /** * Checks if the given word has to be treated as a valid word. Please note that some @@ -190,7 +193,8 @@ public abstract class Dictionary { public ArrayList getSuggestions(final WordComposer composer, final PrevWordsInfo prevWordsInfo, final ProximityInfo proximityInfo, final SettingsValuesForSuggestion settingsValuesForSuggestion, - final int sessionId, final float[] inOutLanguageWeight) { + final int sessionId, final float weightForLocale, + final float[] inOutWeightOfLangModelVsSpatialModel) { return null; } diff --git a/java/src/com/android/inputmethod/latin/DictionaryCollection.java b/java/src/com/android/inputmethod/latin/DictionaryCollection.java index ca5e93714..b26b37817 100644 --- a/java/src/com/android/inputmethod/latin/DictionaryCollection.java +++ b/java/src/com/android/inputmethod/latin/DictionaryCollection.java @@ -62,20 +62,21 @@ public final class DictionaryCollection extends Dictionary { public ArrayList getSuggestions(final WordComposer composer, final PrevWordsInfo prevWordsInfo, final ProximityInfo proximityInfo, final SettingsValuesForSuggestion settingsValuesForSuggestion, - final int sessionId, final float[] inOutLanguageWeight) { + final int sessionId, final float weightForLocale, + final float[] inOutWeightOfLangModelVsSpatialModel) { final CopyOnWriteArrayList dictionaries = mDictionaries; if (dictionaries.isEmpty()) return null; // To avoid creating unnecessary objects, we get the list out of the first // dictionary and add the rest to it if not null, hence the get(0) ArrayList suggestions = dictionaries.get(0).getSuggestions(composer, prevWordsInfo, proximityInfo, settingsValuesForSuggestion, sessionId, - inOutLanguageWeight); + weightForLocale, inOutWeightOfLangModelVsSpatialModel); if (null == suggestions) suggestions = new ArrayList<>(); final int length = dictionaries.size(); for (int i = 1; i < length; ++ i) { final ArrayList sugg = dictionaries.get(i).getSuggestions(composer, prevWordsInfo, proximityInfo, settingsValuesForSuggestion, sessionId, - inOutLanguageWeight); + weightForLocale, inOutWeightOfLangModelVsSpatialModel); if (null != sugg) suggestions.addAll(sugg); } return suggestions; diff --git a/java/src/com/android/inputmethod/latin/DictionaryFacilitator.java b/java/src/com/android/inputmethod/latin/DictionaryFacilitator.java index 0f09daf86..9af15c182 100644 --- a/java/src/com/android/inputmethod/latin/DictionaryFacilitator.java +++ b/java/src/com/android/inputmethod/latin/DictionaryFacilitator.java @@ -104,6 +104,7 @@ public class DictionaryFacilitator { private static class DictionaryGroup { public final Locale mLocale; private Dictionary mMainDict; + public float mWeightForLocale = 1.0f; public final ConcurrentHashMap mSubDictMap = new ConcurrentHashMap<>(); @@ -598,14 +599,16 @@ public class DictionaryFacilitator { final SuggestionResults suggestionResults = new SuggestionResults( SuggestedWords.MAX_SUGGESTIONS, prevWordsInfo.mPrevWordsInfo[0].mIsBeginningOfSentence); - final float[] languageWeight = new float[] { Dictionary.NOT_A_LANGUAGE_WEIGHT }; + final float[] weightOfLangModelVsSpatialModel = + new float[] { Dictionary.NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL }; for (final DictionaryGroup dictionaryGroup : dictionaryGroups) { for (final String dictType : DICT_TYPES_ORDERED_TO_GET_SUGGESTIONS) { final Dictionary dictionary = dictionaryGroup.getDict(dictType); if (null == dictionary) continue; final ArrayList dictionarySuggestions = dictionary.getSuggestions(composer, prevWordsInfo, proximityInfo, - settingsValuesForSuggestion, sessionId, languageWeight); + settingsValuesForSuggestion, sessionId, + dictionaryGroup.mWeightForLocale, weightOfLangModelVsSpatialModel); if (null == dictionarySuggestions) continue; suggestionResults.addAll(dictionarySuggestions); if (null != suggestionResults.mRawSuggestions) { diff --git a/java/src/com/android/inputmethod/latin/ExpandableBinaryDictionary.java b/java/src/com/android/inputmethod/latin/ExpandableBinaryDictionary.java index 671ba6714..ad967c133 100644 --- a/java/src/com/android/inputmethod/latin/ExpandableBinaryDictionary.java +++ b/java/src/com/android/inputmethod/latin/ExpandableBinaryDictionary.java @@ -435,7 +435,7 @@ abstract public class ExpandableBinaryDictionary extends Dictionary { public ArrayList getSuggestions(final WordComposer composer, final PrevWordsInfo prevWordsInfo, final ProximityInfo proximityInfo, final SettingsValuesForSuggestion settingsValuesForSuggestion, final int sessionId, - final float[] inOutLanguageWeight) { + final float weightForLocale, final float[] inOutWeightOfLangModelVsSpatialModel) { reloadDictionaryIfRequired(); boolean lockAcquired = false; try { @@ -447,7 +447,8 @@ abstract public class ExpandableBinaryDictionary extends Dictionary { } final ArrayList suggestions = mBinaryDictionary.getSuggestions(composer, prevWordsInfo, proximityInfo, - settingsValuesForSuggestion, sessionId, inOutLanguageWeight); + settingsValuesForSuggestion, sessionId, weightForLocale, + inOutWeightOfLangModelVsSpatialModel); if (mBinaryDictionary.isCorrupted()) { Log.i(TAG, "Dictionary (" + mDictName +") is corrupted. " + "Remove and regenerate it."); diff --git a/java/src/com/android/inputmethod/latin/ReadOnlyBinaryDictionary.java b/java/src/com/android/inputmethod/latin/ReadOnlyBinaryDictionary.java index ecf25c28b..827367bb4 100644 --- a/java/src/com/android/inputmethod/latin/ReadOnlyBinaryDictionary.java +++ b/java/src/com/android/inputmethod/latin/ReadOnlyBinaryDictionary.java @@ -53,11 +53,13 @@ public final class ReadOnlyBinaryDictionary extends Dictionary { public ArrayList getSuggestions(final WordComposer composer, final PrevWordsInfo prevWordsInfo, final ProximityInfo proximityInfo, final SettingsValuesForSuggestion settingsValuesForSuggestion, - final int sessionId, final float[] inOutLanguageWeight) { + final int sessionId, final float weightForLocale, + final float[] inOutWeightOfLangModelVsSpatialModel) { if (mLock.readLock().tryLock()) { try { return mBinaryDictionary.getSuggestions(composer, prevWordsInfo, proximityInfo, - settingsValuesForSuggestion, sessionId, inOutLanguageWeight); + settingsValuesForSuggestion, sessionId, weightForLocale, + inOutWeightOfLangModelVsSpatialModel); } finally { mLock.readLock().unlock(); } diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index e65dc4c06..688ce44be 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -182,7 +182,8 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, jobjectArray prevWordCodePointArrays, jbooleanArray isBeginningOfSentenceArray, jint prevWordCount, jintArray outSuggestionCount, jintArray outCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, - jintArray outAutoCommitFirstWordConfidenceArray, jfloatArray inOutLanguageWeight) { + jintArray outAutoCommitFirstWordConfidenceArray, + jfloatArray inOutWeightOfLangModelVsSpatialModel) { Dictionary *dictionary = reinterpret_cast(dict); // Assign 0 to outSuggestionCount here in case of returning earlier in this method. JniDataUtils::putIntToArray(env, outSuggestionCount, 0 /* index */, 0); @@ -237,8 +238,9 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, ASSERT(false); return; } - float languageWeight; - env->GetFloatArrayRegion(inOutLanguageWeight, 0, 1 /* len */, &languageWeight); + float weightOfLangModelVsSpatialModel; + env->GetFloatArrayRegion(inOutWeightOfLangModelVsSpatialModel, 0, 1 /* len */, + &weightOfLangModelVsSpatialModel); SuggestionResults suggestionResults(MAX_RESULTS); const PrevWordsInfo prevWordsInfo = JniDataUtils::constructPrevWordsInfo(env, prevWordCodePointArrays, isBeginningOfSentenceArray, prevWordCount); @@ -246,13 +248,13 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, // TODO: Use SuggestionResults to return suggestions. dictionary->getSuggestions(pInfo, traverseSession, xCoordinates, yCoordinates, times, pointerIds, inputCodePoints, inputSize, &prevWordsInfo, - &givenSuggestOptions, languageWeight, &suggestionResults); + &givenSuggestOptions, weightOfLangModelVsSpatialModel, &suggestionResults); } else { dictionary->getPredictions(&prevWordsInfo, &suggestionResults); } suggestionResults.outputSuggestions(env, outSuggestionCount, outCodePointsArray, outScoresArray, outSpaceIndicesArray, outTypesArray, - outAutoCommitFirstWordConfidenceArray, inOutLanguageWeight); + outAutoCommitFirstWordConfidenceArray, inOutWeightOfLangModelVsSpatialModel); } static jint latinime_BinaryDictionary_getProbability(JNIEnv *env, jclass clazz, jlong dict, diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 56d8bbb72..e55c9eb8a 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -301,7 +301,7 @@ static inline void prof_out(void) { #define NOT_A_DICT_POS (S_INT_MIN) #define NOT_A_WORD_ID (S_INT_MIN) #define NOT_A_TIMESTAMP (-1) -#define NOT_A_LANGUAGE_WEIGHT (-1.0f) +#define NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL (-1.0f) // A special value to mean the first word confidence makes no sense in this case, // e.g. this is not a multi-word suggestion. diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index ec61783cb..5214077dc 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -295,8 +295,9 @@ class DicNode { } // Used to prune nodes - float getCompoundDistance(const float languageWeight) const { - return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); + float getCompoundDistance(const float weightOfLangModelVsSpatialModel) const { + return mDicNodeState.mDicNodeStateScoring.getCompoundDistance( + weightOfLangModelVsSpatialModel); } AK_FORCE_INLINE const int *getOutputWordBuf() const { diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h index c19d48eb9..3a54c2599 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h @@ -103,8 +103,10 @@ class DicNodeStateScoring { return getCompoundDistance(1.0f); } - float getCompoundDistance(const float languageWeight) const { - return mSpatialDistance + mLanguageDistance * languageWeight; + float getCompoundDistance( + const float weightOfLangModelVsSpatialModel) const { + return mSpatialDistance + + mLanguageDistance * weightOfLangModelVsSpatialModel; } float getNormalizedCompoundDistance() const { diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index f9f36ce44..e4084b0f5 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -47,14 +47,14 @@ Dictionary::Dictionary(JNIEnv *env, DictionaryStructureWithBufferPolicy::Structu void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, const PrevWordsInfo *const prevWordsInfo, - const SuggestOptions *const suggestOptions, const float languageWeight, + const SuggestOptions *const suggestOptions, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); traverseSession->init(this, prevWordsInfo, suggestOptions); const auto &suggest = suggestOptions->isGesture() ? mGestureSuggest : mTypingSuggest; suggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, - languageWeight, outSuggestionResults); + weightOfLangModelVsSpatialModel, outSuggestionResults); if (DEBUG_DICT) { outSuggestionResults->dumpSuggestions(); } diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index f6482ab78..324e3504a 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -66,7 +66,7 @@ class Dictionary { void getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, const PrevWordsInfo *const prevWordsInfo, - const SuggestOptions *const suggestOptions, const float languageWeight, + const SuggestOptions *const suggestOptions, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const; void getPredictions(const PrevWordsInfo *const prevWordsInfo, diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index 9e75cace4..ce3684a1c 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -32,9 +32,11 @@ class Scoring { const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, const bool boostExactMatches) const = 0; virtual void getMostProbableString(const DicTraverseSession *const traverseSession, - const float languageWeight, SuggestionResults *const outSuggestionResults) const = 0; - virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, - DicNode *const terminals, const int size) const = 0; + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) const = 0; + virtual float getAdjustedWeightOfLangModelVsSpatialModel( + DicTraverseSession *const traverseSession, DicNode *const terminals, + const int size) const = 0; virtual float getDoubleLetterDemotionDistanceCost( const DicNode *const terminalDicNode) const = 0; virtual bool autoCorrectsToMultiWordSuggestionIfTop() const = 0; diff --git a/native/jni/src/suggest/core/result/suggestion_results.cpp b/native/jni/src/suggest/core/result/suggestion_results.cpp index 4c10bd08a..3756d1092 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.cpp +++ b/native/jni/src/suggest/core/result/suggestion_results.cpp @@ -23,7 +23,7 @@ namespace latinime { void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outputCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray, - jfloatArray outLanguageWeight) { + jfloatArray outWeightOfLangModelVsSpatialModel) { int outputIndex = 0; while (!mSuggestedWords.empty()) { const SuggestedWord &suggestedWord = mSuggestedWords.top(); @@ -44,7 +44,8 @@ void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCo mSuggestedWords.pop(); } JniDataUtils::putIntToArray(env, outSuggestionCount, 0 /* index */, outputIndex); - JniDataUtils::putFloatToArray(env, outLanguageWeight, 0 /* index */, mLanguageWeight); + JniDataUtils::putFloatToArray(env, outWeightOfLangModelVsSpatialModel, 0 /* index */, + mWeightOfLangModelVsSpatialModel); } void SuggestionResults::addPrediction(const int *const codePoints, const int codePointCount, @@ -89,7 +90,7 @@ void SuggestionResults::getSortedScores(int *const outScores) const { } void SuggestionResults::dumpSuggestions() const { - AKLOGE("language weight: %f", mLanguageWeight); + AKLOGE("weight of language model vs spatial model: %f", mWeightOfLangModelVsSpatialModel); std::vector suggestedWords; auto copyOfSuggestedWords = mSuggestedWords; while (!copyOfSuggestedWords.empty()) { diff --git a/native/jni/src/suggest/core/result/suggestion_results.h b/native/jni/src/suggest/core/result/suggestion_results.h index 8e845e2d3..738c78a9f 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.h +++ b/native/jni/src/suggest/core/result/suggestion_results.h @@ -29,13 +29,15 @@ namespace latinime { class SuggestionResults { public: explicit SuggestionResults(const int maxSuggestionCount) - : mMaxSuggestionCount(maxSuggestionCount), mLanguageWeight(NOT_A_LANGUAGE_WEIGHT), + : mMaxSuggestionCount(maxSuggestionCount), + mWeightOfLangModelVsSpatialModel(NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL), mSuggestedWords() {} // Returns suggestion count. void outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, - jintArray outAutoCommitFirstWordConfidenceArray, jfloatArray outLanguageWeight); + jintArray outAutoCommitFirstWordConfidenceArray, + jfloatArray outWeightOfLangModelVsSpatialModel); void addPrediction(const int *const codePoints, const int codePointCount, const int score); void addSuggestion(const int *const codePoints, const int codePointCount, const int score, const int type, const int indexToPartialCommit, @@ -43,8 +45,8 @@ class SuggestionResults { void getSortedScores(int *const outScores) const; void dumpSuggestions() const; - void setLanguageWeight(const float languageWeight) { - mLanguageWeight = languageWeight; + void setWeightOfLangModelVsSpatialModel(const float weightOfLangModelVsSpatialModel) { + mWeightOfLangModelVsSpatialModel = weightOfLangModelVsSpatialModel; } int getSuggestionCount() const { @@ -55,7 +57,7 @@ class SuggestionResults { DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestionResults); const int mMaxSuggestionCount; - float mLanguageWeight; + float mWeightOfLangModelVsSpatialModel; std::priority_queue< SuggestedWord, std::vector, SuggestedWord::Comparator> mSuggestedWords; }; diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index 6e0193772..3283f6deb 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -34,7 +34,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; /* static */ void SuggestionsOutputUtils::outputSuggestions( const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, - const float languageWeight, SuggestionResults *const outSuggestionResults) { + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) { #if DEBUG_EVALUATE_MOST_PROBABLE_STRING const int terminalSize = 0; #else @@ -44,12 +45,15 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; for (int index = terminalSize - 1; index >= 0; --index) { traverseSession->getDicTraverseCache()->popTerminal(&terminals[index]); } - // Compute a language weight when an invalid language weight is passed. - // NOT_A_LANGUAGE_WEIGHT (-1) is assumed as an invalid language weight. - const float languageWeightToOutputSuggestions = (languageWeight < 0.0f) ? - scoringPolicy->getAdjustedLanguageWeight( - traverseSession, terminals.data(), terminalSize) : languageWeight; - outSuggestionResults->setLanguageWeight(languageWeightToOutputSuggestions); + // Compute a weight of language model when an invalid weight is passed. + // NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL (-1) is taken as an invalid value. + const float weightOfLangModelVsSpatialModelToOutputSuggestions = + (weightOfLangModelVsSpatialModel < 0.0f) + ? scoringPolicy->getAdjustedWeightOfLangModelVsSpatialModel(traverseSession, + terminals.data(), terminalSize) + : weightOfLangModelVsSpatialModel; + outSuggestionResults->setWeightOfLangModelVsSpatialModel( + weightOfLangModelVsSpatialModelToOutputSuggestions); // Force autocorrection for obvious long multi-word suggestions when the top suggestion is // a long multiple words suggestion. // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. @@ -65,16 +69,16 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; // Output suggestion results here for (auto &terminalDicNode : terminals) { outputSuggestionsOfDicNode(scoringPolicy, traverseSession, &terminalDicNode, - languageWeightToOutputSuggestions, boostExactMatches, forceCommitMultiWords, - outputSecondWordFirstLetterInputIndex, outSuggestionResults); + weightOfLangModelVsSpatialModelToOutputSuggestions, boostExactMatches, + forceCommitMultiWords, outputSecondWordFirstLetterInputIndex, outSuggestionResults); } - scoringPolicy->getMostProbableString(traverseSession, languageWeightToOutputSuggestions, - outSuggestionResults); + scoringPolicy->getMostProbableString(traverseSession, + weightOfLangModelVsSpatialModelToOutputSuggestions, outSuggestionResults); } /* static */ void SuggestionsOutputUtils::outputSuggestionsOfDicNode( const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, - const DicNode *const terminalDicNode, const float languageWeight, + const DicNode *const terminalDicNode, const float weightOfLangModelVsSpatialModel, const bool boostExactMatches, const bool forceCommitMultiWords, const bool outputSecondWordFirstLetterInputIndex, SuggestionResults *const outSuggestionResults) { @@ -83,8 +87,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; } const float doubleLetterCost = scoringPolicy->getDoubleLetterDemotionDistanceCost(terminalDicNode); - const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) - + doubleLetterCost; + const float compoundDistance = + terminalDicNode->getCompoundDistance(weightOfLangModelVsSpatialModel) + + doubleLetterCost; const WordAttributes wordAttributes = traverseSession->getDictionaryStructurePolicy() ->getWordAttributesInContext(terminalDicNode->getPrevWordIds(), terminalDicNode->getWordId(), nullptr /* multiBigramMap */); diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.h b/native/jni/src/suggest/core/result/suggestions_output_utils.h index b099b4776..bf8497828 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.h +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.h @@ -33,7 +33,7 @@ class SuggestionsOutputUtils { * Outputs the final list of suggestions (i.e., terminal nodes). */ static void outputSuggestions(const Scoring *const scoringPolicy, - DicTraverseSession *traverseSession, const float languageWeight, + DicTraverseSession *traverseSession, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults); private: @@ -44,7 +44,7 @@ class SuggestionsOutputUtils { static void outputSuggestionsOfDicNode(const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, const DicNode *const terminalDicNode, - const float languageWeight, const bool boostExactMatches, + const float weightOfLangModelVsSpatialModel, const bool boostExactMatches, const bool forceCommitMultiWords, const bool outputSecondWordFirstLetterInputIndex, SuggestionResults *const outSuggestionResults); static void outputShortcuts(BinaryDictionaryShortcutIterator *const shortcutIt, diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 947d41f4b..457414f2b 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -45,7 +45,7 @@ const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; */ void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, const float languageWeight, + int inputSize, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const { PROF_OPEN; PROF_START(0); @@ -68,7 +68,7 @@ void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, PROF_END(1); PROF_START(2); SuggestionsOutputUtils::outputSuggestions( - SCORING, tSession, languageWeight, outSuggestionResults); + SCORING, tSession, weightOfLangModelVsSpatialModel, outSuggestionResults); PROF_END(2); PROF_CLOSE; } diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 788e0314b..65d5918cf 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -49,7 +49,8 @@ class Suggest : public SuggestInterface { AK_FORCE_INLINE virtual ~Suggest() {} void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - const float languageWeight, SuggestionResults *const outSuggestionResults) const; + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(Suggest); diff --git a/native/jni/src/suggest/core/suggest_interface.h b/native/jni/src/suggest/core/suggest_interface.h index a6e5aefae..a05aa9c80 100644 --- a/native/jni/src/suggest/core/suggest_interface.h +++ b/native/jni/src/suggest/core/suggest_interface.h @@ -28,7 +28,8 @@ class SuggestInterface { public: virtual void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - const float languageWeight, SuggestionResults *const suggestionResults) const = 0; + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const suggestionResults) const = 0; SuggestInterface() {} virtual ~SuggestInterface() {} private: diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 52c4251f0..0240bcf54 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -33,10 +33,12 @@ class TypingScoring : public Scoring { static const TypingScoring *getInstance() { return &sInstance; } AK_FORCE_INLINE void getMostProbableString(const DicTraverseSession *const traverseSession, - const float languageWeight, SuggestionResults *const outSuggestionResults) const {} + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) const {} - AK_FORCE_INLINE float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, - DicNode *const terminals, const int size) const { + AK_FORCE_INLINE float getAdjustedWeightOfLangModelVsSpatialModel( + DicTraverseSession *const traverseSession, DicNode *const terminals, + const int size) const { return 1.0f; }