diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 377015371..41ef9d2b2 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -143,7 +143,7 @@ class DicNode { dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), dicNode->getOutputWordBuf(), dicNode->mDicNodeProperties.getDepth(), - dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevSpacePositions, + dicNode->mDicNodeState.mDicNodeStatePrevWord.getSecondWordFirstInputIndex(), mDicNodeState.mDicNodeStateInput.getInputIndex(0) /* lastInputIndex */); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } @@ -321,8 +321,13 @@ class DicNode { DUMP_WORD_AND_SCORE("OUTPUT"); } - void outputSpacePositionsResult(int *spaceIndices) const { - mDicNodeState.mDicNodeStatePrevWord.outputSpacePositions(spaceIndices); + int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const { + const int inputIndex = mDicNodeState.mDicNodeStatePrevWord.getSecondWordFirstInputIndex(); + if (inputIndex == NOT_AN_INDEX) { + return NOT_AN_INDEX; + } else { + return pInfoState->getInputIndexOfSampledPoint(inputIndex); + } } bool hasMultipleWords() const { @@ -573,7 +578,11 @@ class DicNode { } } - AK_FORCE_INLINE void updateInputIndexG(DicNode_InputStateG *inputStateG) { + AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) { + if (mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() == 1 && isFirstLetter()) { + mDicNodeState.mDicNodeStatePrevWord.setSecondWordFirstInputIndex( + inputStateG->mInputIndex); + } mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, inputStateG->mInputIndex, inputStateG->mPrevCodePoint, inputStateG->mTerminalDiffCost, inputStateG->mRawLength); diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h index b7af97018..b8986203d 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/core/dicnode/dic_node_utils.h" +#include "suggest/core/layout/proximity_info_state.h" namespace latinime { @@ -29,9 +30,8 @@ class DicNodeStatePrevWord { public: AK_FORCE_INLINE DicNodeStatePrevWord() : mPrevWordCount(0), mPrevWordLength(0), mPrevWordStart(0), mPrevWordProbability(0), - mPrevWordNodePos(NOT_A_DICT_POS) { + mPrevWordNodePos(NOT_A_DICT_POS), mSecondWordFirstInputIndex(NOT_AN_INDEX) { memset(mPrevWord, 0, sizeof(mPrevWord)); - memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); } virtual ~DicNodeStatePrevWord() {} @@ -42,7 +42,7 @@ class DicNodeStatePrevWord { mPrevWordStart = 0; mPrevWordProbability = -1; mPrevWordNodePos = NOT_A_DICT_POS; - memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); + mSecondWordFirstInputIndex = NOT_AN_INDEX; } void init(const int prevWordNodePos) { @@ -51,7 +51,7 @@ class DicNodeStatePrevWord { mPrevWordStart = 0; mPrevWordProbability = -1; mPrevWordNodePos = prevWordNodePos; - memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); + mSecondWordFirstInputIndex = NOT_AN_INDEX; } // Init by copy @@ -61,14 +61,14 @@ class DicNodeStatePrevWord { mPrevWordStart = prevWord->mPrevWordStart; mPrevWordProbability = prevWord->mPrevWordProbability; mPrevWordNodePos = prevWord->mPrevWordNodePos; + mSecondWordFirstInputIndex = prevWord->mSecondWordFirstInputIndex; memcpy(mPrevWord, prevWord->mPrevWord, prevWord->mPrevWordLength * sizeof(mPrevWord[0])); - memcpy(mPrevSpacePositions, prevWord->mPrevSpacePositions, sizeof(mPrevSpacePositions)); } void init(const int16_t prevWordCount, const int16_t prevWordProbability, const int prevWordNodePos, const int *const src0, const int16_t length0, - const int *const src1, const int16_t length1, const int *const prevSpacePositions, - const int lastInputIndex) { + const int *const src1, const int16_t length1, + const int prevWordSecondWordFirstInputIndex, const int lastInputIndex) { mPrevWordCount = min(prevWordCount, static_cast(MAX_RESULTS)); mPrevWordProbability = prevWordProbability; mPrevWordNodePos = prevWordNodePos; @@ -80,8 +80,7 @@ class DicNodeStatePrevWord { mPrevWord[twoWordsLen] = KEYCODE_SPACE; mPrevWordStart = length0; mPrevWordLength = static_cast(twoWordsLen + 1); - memcpy(mPrevSpacePositions, prevSpacePositions, sizeof(mPrevSpacePositions)); - mPrevSpacePositions[mPrevWordCount - 1] = lastInputIndex; + mSecondWordFirstInputIndex = prevWordSecondWordFirstInputIndex; } void truncate(const int offset) { @@ -96,11 +95,12 @@ class DicNodeStatePrevWord { mPrevWordLength = newPrevWordLength; } - void outputSpacePositions(int *spaceIndices) const { - // Convert uint16_t to int - for (int i = 0; i < MAX_RESULTS; i++) { - spaceIndices[i] = mPrevSpacePositions[i]; - } + void setSecondWordFirstInputIndex(const int inputIndex) { + mSecondWordFirstInputIndex = inputIndex; + } + + int getSecondWordFirstInputIndex() const { + return mSecondWordFirstInputIndex; } // TODO: remove @@ -138,8 +138,6 @@ class DicNodeStatePrevWord { // TODO: Move to private int mPrevWord[MAX_WORD_LENGTH]; - // TODO: Move to private - int mPrevSpacePositions[MAX_RESULTS]; private: // Caution!!! @@ -150,6 +148,7 @@ class DicNodeStatePrevWord { int16_t mPrevWordStart; int16_t mPrevWordProbability; int mPrevWordNodePos; + int mSecondWordFirstInputIndex; }; } // namespace latinime #endif // LATINIME_DIC_NODE_STATE_PREVWORD_H diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.h b/native/jni/src/suggest/core/layout/proximity_info_state.h index 01bf81864..c94060fa9 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state.h @@ -130,6 +130,10 @@ class ProximityInfoState { return mSampledInputYs[index]; } + int getInputIndexOfSampledPoint(const int sampledIndex) const { + return mSampledInputIndice[sampledIndex]; + } + bool hasSpaceProximity(const int index) const; int getLengthCache(const int index) const { diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index e2ef5fc76..e0b1c67d9 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -113,7 +113,9 @@ class DicTraverseSession { if (usedPointerCount != 1) { return false; } - *pointerId = usedPointerId; + if (pointerId) { + *pointerId = usedPointerId; + } return true; } diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index e788e914a..0c925be25 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -117,7 +117,7 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo * Outputs the final list of suggestions (i.e., terminal nodes). */ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, - int *outputCodePoints, int *spaceIndices, int *outputTypes) const { + int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes) const { #if DEBUG_EVALUATE_MOST_PROBABLE_STRING const int terminalSize = 0; #else @@ -139,6 +139,7 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen SCORING->getMostProbableString(traverseSession, terminalSize, languageWeight, &outputCodePoints[0], &outputTypes[0], &frequencies[0]); if (hasMostProbableString) { + outputIndicesToPartialCommit[outputWordIndex] = NOT_AN_INDEX; ++outputWordIndex; } @@ -160,6 +161,9 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen || (traverseSession->getInputSize() >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT && terminals[0].hasMultipleWords())) : false; + // TODO: have partial commit work even with multiple pointers. + const bool outputSecondWordFirstLetterInputIndex = + traverseSession->isOnlyOnePointerUsed(0 /* pointerId */); // Output suggestion results here for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; ++terminalIndex) { @@ -194,18 +198,21 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen terminalDicNode->isExactMatch() || (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) || (isValidWord && SCORING->doesAutoCorrectValidWord())); - maxScore = max(maxScore, finalScore); - - // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. - // Index for top typing suggestion should be 0. - if (isValidWord && outputWordIndex == 0) { - terminalDicNode->outputSpacePositionsResult(spaceIndices); + if (maxScore < finalScore && isValidWord) { + maxScore = finalScore; } // Don't output invalid words. However, we still need to submit their shortcuts if any. if (isValidWord) { outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION | outputTypeFlags; frequencies[outputWordIndex] = finalScore; + if (outputSecondWordFirstLetterInputIndex) { + outputIndicesToPartialCommit[outputWordIndex] = + terminalDicNode->getSecondWordFirstInputIndex( + traverseSession->getProximityInfoState(0)); + } else { + outputIndicesToPartialCommit[outputWordIndex] = NOT_AN_INDEX; + } // Populate the outputChars array with the suggested word. const int startIndex = outputWordIndex * MAX_WORD_LENGTH; terminalDicNode->outputResult(&outputCodePoints[startIndex]); @@ -220,8 +227,19 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen // Shortcut is not supported for multiple words suggestions. // TODO: Check shortcuts during traversal for multiple words suggestions. const bool sameAsTyped = TRAVERSAL->sameAsTyped(traverseSession, terminalDicNode); - outputWordIndex = ShortcutUtils::outputShortcuts(&shortcutIt, outputWordIndex, - finalScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); + const int updatedOutputWordIndex = ShortcutUtils::outputShortcuts(&shortcutIt, + outputWordIndex, finalScore, outputCodePoints, frequencies, outputTypes, + sameAsTyped); + const int secondWordFirstInputIndex = terminalDicNode->getSecondWordFirstInputIndex( + traverseSession->getProximityInfoState(0)); + for (int i = outputWordIndex; i < updatedOutputWordIndex; ++i) { + if (outputSecondWordFirstLetterInputIndex) { + outputIndicesToPartialCommit[i] = secondWordFirstInputIndex; + } else { + outputIndicesToPartialCommit[i] = NOT_AN_INDEX; + } + } + outputWordIndex = updatedOutputWordIndex; } DicNode::managedDelete(terminalDicNode); } diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 875cbe4e0..b24019632 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -55,7 +55,7 @@ class Suggest : public SuggestInterface { void createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, const bool spaceSubstitution) const; int outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, - int *outputCodePoints, int *outputIndices, int *outputTypes) const; + int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes) const; void initializeSearch(DicTraverseSession *traverseSession, int commitPoint) const; void expandCurrentDicNodes(DicTraverseSession *traverseSession) const; void processTerminalDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const;