diff --git a/java/src/com/android/inputmethod/latin/BinaryDictionary.java b/java/src/com/android/inputmethod/latin/BinaryDictionary.java index 95e1340ed..1b5791809 100644 --- a/java/src/com/android/inputmethod/latin/BinaryDictionary.java +++ b/java/src/com/android/inputmethod/latin/BinaryDictionary.java @@ -174,8 +174,8 @@ public final class BinaryDictionary extends Dictionary { private static native int getFormatVersionNative(long dict); private static native int getProbabilityNative(long dict, int[] word); private static native int getMaxProbabilityOfExactMatchesNative(long dict, int[] word); - private static native int getNgramProbabilityNative(long dict, int[] word0, - boolean isBeginningOfSentence, int[] word1); + private static native int getNgramProbabilityNative(long dict, int[][] prevWordCodePointArrays, + boolean[] isBeginningOfSentenceArray, int[] word); private static native void getWordPropertyNative(long dict, int[] word, boolean isBeginningOfSentence, int[] outCodePoints, boolean[] outFlags, int[] outProbabilityInfo, ArrayList outBigramTargets, @@ -186,7 +186,7 @@ public final class BinaryDictionary extends Dictionary { private static native void getSuggestionsNative(long dict, long proximityInfo, long traverseSession, int[] xCoordinates, int[] yCoordinates, int[] times, int[] pointerIds, int[] inputCodePoints, int inputSize, int[] suggestOptions, - int[] prevWordCodePointArray, boolean isBeginningOfSentence, + int[][] prevWordCodePointArrays, boolean[] isBeginningOfSentenceArray, int[] outputSuggestionCount, int[] outputCodePoints, int[] outputScores, int[] outputIndices, int[] outputTypes, int[] outputAutoCommitFirstWordConfidence, float[] inOutLanguageWeight); @@ -194,10 +194,11 @@ public final class BinaryDictionary extends Dictionary { int[] shortcutTarget, int shortcutProbability, boolean isBeginningOfSentence, boolean isNotAWord, boolean isBlacklisted, int timestamp); private static native boolean removeUnigramEntryNative(long dict, int[] word); - private static native boolean addNgramEntryNative(long dict, int[] word0, - boolean isBeginningOfSentence, int[] word1, int probability, int timestamp); - private static native boolean removeNgramEntryNative(long dict, int[] word0, - boolean isBeginningOfSentence, int[] word1); + private static native boolean addNgramEntryNative(long dict, + int[][] prevWordCodePointArrays, boolean[] isBeginningOfSentenceArray, + int[] word, int probability, int timestamp); + private static native boolean removeNgramEntryNative(long dict, + int[][] prevWordCodePointArrays, boolean[] isBeginningOfSentenceArray, int[] word); private static native int addMultipleDictionaryEntriesNative(long dict, LanguageModelParam[] languageModelParams, int startIndex); private static native String getPropertyNative(long dict, String query); @@ -290,8 +291,8 @@ public final class BinaryDictionary extends Dictionary { getTraverseSession(sessionId).getSession(), inputPointers.getXCoordinates(), inputPointers.getYCoordinates(), inputPointers.getTimes(), inputPointers.getPointerIds(), session.mInputCodePoints, inputSize, - session.mNativeSuggestOptions.getOptions(), session.mPrevWordCodePointArrays[0], - session.mIsBeginningOfSentenceArray[0], session.mOutputSuggestionCount, + session.mNativeSuggestOptions.getOptions(), session.mPrevWordCodePointArrays, + session.mIsBeginningOfSentenceArray, session.mOutputSuggestionCount, session.mOutputCodePoints, session.mOutputScores, session.mSpaceIndices, session.mOutputTypes, session.mOutputAutoCommitFirstWordConfidence, session.mInputOutputLanguageWeight); @@ -359,8 +360,8 @@ public final class BinaryDictionary extends Dictionary { new boolean[Constants.MAX_PREV_WORD_COUNT_FOR_N_GRAM]; prevWordsInfo.outputToArray(prevWordCodePointArrays, isBeginningOfSentenceArray); final int[] wordCodePoints = StringUtils.toCodePointArray(word); - return getNgramProbabilityNative(mNativeDict, prevWordCodePointArrays[0], - isBeginningOfSentenceArray[0], wordCodePoints); + return getNgramProbabilityNative(mNativeDict, prevWordCodePointArrays, + isBeginningOfSentenceArray, wordCodePoints); } public WordProperty getWordProperty(final String word, final boolean isBeginningOfSentence) { @@ -456,8 +457,8 @@ public final class BinaryDictionary extends Dictionary { new boolean[Constants.MAX_PREV_WORD_COUNT_FOR_N_GRAM]; prevWordsInfo.outputToArray(prevWordCodePointArrays, isBeginningOfSentenceArray); final int[] wordCodePoints = StringUtils.toCodePointArray(word); - if (!addNgramEntryNative(mNativeDict, prevWordCodePointArrays[0], - isBeginningOfSentenceArray[0], wordCodePoints, probability, timestamp)) { + if (!addNgramEntryNative(mNativeDict, prevWordCodePointArrays, + isBeginningOfSentenceArray, wordCodePoints, probability, timestamp)) { return false; } mHasUpdated = true; @@ -474,8 +475,8 @@ public final class BinaryDictionary extends Dictionary { new boolean[Constants.MAX_PREV_WORD_COUNT_FOR_N_GRAM]; prevWordsInfo.outputToArray(prevWordCodePointArrays, isBeginningOfSentenceArray); final int[] wordCodePoints = StringUtils.toCodePointArray(word); - if (!removeNgramEntryNative(mNativeDict, prevWordCodePointArrays[0], - isBeginningOfSentenceArray[0], wordCodePoints)) { + if (!removeNgramEntryNative(mNativeDict, prevWordCodePointArrays, + isBeginningOfSentenceArray, wordCodePoints)) { return false; } mHasUpdated = true; diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index 6b4fb7986..3add84a0a 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -178,7 +178,7 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, jlong proximityInfo, jlong dicTraverseSession, jintArray xCoordinatesArray, jintArray yCoordinatesArray, jintArray timesArray, jintArray pointerIdsArray, jintArray inputCodePointsArray, jint inputSize, jintArray suggestOptions, - jintArray prevWordCodePointsForBigrams, jboolean isBeginningOfSentence, + jobjectArray prevWordCodePointArrays, jbooleanArray isBeginningOfSentenceArray, jintArray outSuggestionCount, jintArray outCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray, jfloatArray inOutLanguageWeight) { @@ -201,20 +201,11 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, int pointerIds[inputSize]; const jsize inputCodePointsLength = env->GetArrayLength(inputCodePointsArray); int inputCodePoints[inputCodePointsLength]; - const jsize prevWordCodePointsLength = - prevWordCodePointsForBigrams ? env->GetArrayLength(prevWordCodePointsForBigrams) : 0; - int prevWordCodePointsInternal[prevWordCodePointsLength]; - int *prevWordCodePoints = nullptr; env->GetIntArrayRegion(xCoordinatesArray, 0, inputSize, xCoordinates); env->GetIntArrayRegion(yCoordinatesArray, 0, inputSize, yCoordinates); env->GetIntArrayRegion(timesArray, 0, inputSize, times); env->GetIntArrayRegion(pointerIdsArray, 0, inputSize, pointerIds); env->GetIntArrayRegion(inputCodePointsArray, 0, inputCodePointsLength, inputCodePoints); - if (prevWordCodePointsForBigrams) { - env->GetIntArrayRegion(prevWordCodePointsForBigrams, 0, prevWordCodePointsLength, - prevWordCodePointsInternal); - prevWordCodePoints = prevWordCodePointsInternal; - } const jsize numberOfOptions = env->GetArrayLength(suggestOptions); int options[numberOfOptions]; @@ -248,8 +239,8 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, float languageWeight; env->GetFloatArrayRegion(inOutLanguageWeight, 0, 1 /* len */, &languageWeight); SuggestionResults suggestionResults(MAX_RESULTS); - const PrevWordsInfo prevWordsInfo(prevWordCodePoints, prevWordCodePointsLength, - isBeginningOfSentence); + const PrevWordsInfo prevWordsInfo = JniDataUtils::constructPrevWordsInfo(env, + prevWordCodePointArrays, isBeginningOfSentenceArray); if (givenSuggestOptions.isGesture() || inputSize > 0) { // TODO: Use SuggestionResults to return suggestions. dictionary->getSuggestions(pInfo, traverseSession, xCoordinates, yCoordinates, @@ -391,41 +382,38 @@ static bool latinime_BinaryDictionary_removeUnigramEntry(JNIEnv *env, jclass cla } static bool latinime_BinaryDictionary_addNgramEntry(JNIEnv *env, jclass clazz, jlong dict, - jintArray word0, jboolean isBeginningOfSentence, jintArray word1, jint probability, - jint timestamp) { + jobjectArray prevWordCodePointArrays, jbooleanArray isBeginningOfSentenceArray, + jintArray word, jint probability, jint timestamp) { Dictionary *dictionary = reinterpret_cast(dict); if (!dictionary) { return false; } - jsize word0Length = env->GetArrayLength(word0); - int word0CodePoints[word0Length]; - env->GetIntArrayRegion(word0, 0, word0Length, word0CodePoints); - jsize word1Length = env->GetArrayLength(word1); - int word1CodePoints[word1Length]; - env->GetIntArrayRegion(word1, 0, word1Length, word1CodePoints); + const PrevWordsInfo prevWordsInfo = JniDataUtils::constructPrevWordsInfo(env, + prevWordCodePointArrays, isBeginningOfSentenceArray); + jsize wordLength = env->GetArrayLength(word); + int wordCodePoints[wordLength]; + env->GetIntArrayRegion(word, 0, wordLength, wordCodePoints); const std::vector bigramTargetCodePoints( - word1CodePoints, word1CodePoints + word1Length); + wordCodePoints, wordCodePoints + wordLength); // Use 1 for count to indicate the bigram has inputted. const BigramProperty bigramProperty(&bigramTargetCodePoints, probability, timestamp, 0 /* level */, 1 /* count */); - const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, isBeginningOfSentence); return dictionary->addNgramEntry(&prevWordsInfo, &bigramProperty); } static bool latinime_BinaryDictionary_removeNgramEntry(JNIEnv *env, jclass clazz, jlong dict, - jintArray word0, jboolean isBeginningOfSentence, jintArray word1) { + jobjectArray prevWordCodePointArrays, jbooleanArray isBeginningOfSentenceArray, + jintArray word) { Dictionary *dictionary = reinterpret_cast(dict); if (!dictionary) { return false; } - jsize word0Length = env->GetArrayLength(word0); - int word0CodePoints[word0Length]; - env->GetIntArrayRegion(word0, 0, word0Length, word0CodePoints); - jsize word1Length = env->GetArrayLength(word1); - int word1CodePoints[word1Length]; - env->GetIntArrayRegion(word1, 0, word1Length, word1CodePoints); - const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, isBeginningOfSentence); - return dictionary->removeNgramEntry(&prevWordsInfo, word1CodePoints, word1Length); + const PrevWordsInfo prevWordsInfo = JniDataUtils::constructPrevWordsInfo(env, + prevWordCodePointArrays, isBeginningOfSentenceArray); + jsize wordLength = env->GetArrayLength(word); + int wordCodePoints[wordLength]; + env->GetIntArrayRegion(word, 0, wordLength, wordCodePoints); + return dictionary->removeNgramEntry(&prevWordsInfo, wordCodePoints, wordLength); } // Returns how many language model params are processed. @@ -672,7 +660,7 @@ static const JNINativeMethod sMethods[] = { }, { const_cast("getSuggestionsNative"), - const_cast("(JJJ[I[I[I[I[II[I[IZ[I[I[I[I[I[I[F)V"), + const_cast("(JJJ[I[I[I[I[II[I[[I[Z[I[I[I[I[I[I[F)V"), reinterpret_cast(latinime_BinaryDictionary_getSuggestions) }, { @@ -687,7 +675,7 @@ static const JNINativeMethod sMethods[] = { }, { const_cast("getNgramProbabilityNative"), - const_cast("(J[IZ[I)I"), + const_cast("(J[[I[Z[I)I"), reinterpret_cast(latinime_BinaryDictionary_getNgramProbability) }, { @@ -713,12 +701,12 @@ static const JNINativeMethod sMethods[] = { }, { const_cast("addNgramEntryNative"), - const_cast("(J[IZ[III)Z"), + const_cast("(J[[I[Z[III)Z"), reinterpret_cast(latinime_BinaryDictionary_addNgramEntry) }, { const_cast("removeNgramEntryNative"), - const_cast("(J[IZ[I)Z"), + const_cast("(J[[I[Z[I)Z"), reinterpret_cast(latinime_BinaryDictionary_removeNgramEntry) }, { diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h index 640f6a2fc..e350c6996 100644 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ b/native/jni/src/suggest/core/session/prev_words_info.h @@ -25,7 +25,6 @@ namespace latinime { // TODO: Support n-gram. -// This class does not take ownership of any code point buffers. class PrevWordsInfo { public: // No prev word information. @@ -33,21 +32,52 @@ class PrevWordsInfo { clear(); } + PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) { + for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { + mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i]; + memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i], + sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); + mIsBeginningOfSentence[i] = prevWordsInfo.mIsBeginningOfSentence[i]; + } + } + + // Construct from previous words. + PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH], + const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, + const size_t prevWordCount) { + clear(); + for (size_t i = 0; i < std::min(NELEMS(mPrevWordCodePoints), prevWordCount); ++i) { + if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) { + continue; + } + memmove(mPrevWordCodePoints[i], prevWordCodePoints[i], + sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]); + mPrevWordCodePointCount[i] = prevWordCodePointCount[i]; + mIsBeginningOfSentence[i] = isBeginningOfSentence[i]; + } + } + + // Construct from a previous word. PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount, const bool isBeginningOfSentence) { clear(); - mPrevWordCodePoints[0] = prevWordCodePoints; + if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { + return; + } + memmove(mPrevWordCodePoints[0], prevWordCodePoints, + sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount); mPrevWordCodePointCount[0] = prevWordCodePointCount; mIsBeginningOfSentence[0] = isBeginningOfSentence; } bool isValid() const { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - if (mPrevWordCodePointCount[i] > MAX_WORD_LENGTH) { - return false; - } + if (mPrevWordCodePointCount[0] > 0) { + return true; } - return true; + if (mIsBeginningOfSentence[0]) { + return true; + } + return false; } void getPrevWordsTerminalPtNodePos( @@ -168,13 +198,12 @@ class PrevWordsInfo { void clear() { for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - mPrevWordCodePoints[i] = nullptr; mPrevWordCodePointCount[i] = 0; mIsBeginningOfSentence[i] = false; } } - const int *mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; }; diff --git a/native/jni/src/utils/jni_data_utils.h b/native/jni/src/utils/jni_data_utils.h index 3514aeeb0..cb82d3c3b 100644 --- a/native/jni/src/utils/jni_data_utils.h +++ b/native/jni/src/utils/jni_data_utils.h @@ -21,6 +21,7 @@ #include "defines.h" #include "jni.h" +#include "suggest/core/session/prev_words_info.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" #include "utils/char_utils.h" @@ -95,6 +96,37 @@ class JniDataUtils { } } + static PrevWordsInfo constructPrevWordsInfo(JNIEnv *env, jobjectArray prevWordCodePointArrays, + jbooleanArray isBeginningOfSentenceArray) { + int prevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; + int prevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + bool isBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + jsize prevWordsCount = env->GetArrayLength(prevWordCodePointArrays); + for (size_t i = 0; i < NELEMS(prevWordCodePoints); ++i) { + prevWordCodePointCount[i] = 0; + isBeginningOfSentence[i] = false; + if (prevWordsCount <= static_cast(i)) { + continue; + } + jintArray prevWord = (jintArray)env->GetObjectArrayElement(prevWordCodePointArrays, i); + if (!prevWord) { + continue; + } + jsize prevWordLength = env->GetArrayLength(prevWord); + if (prevWordLength > MAX_WORD_LENGTH) { + continue; + } + env->GetIntArrayRegion(prevWord, 0, prevWordLength, prevWordCodePoints[i]); + prevWordCodePointCount[i] = prevWordLength; + jboolean isBeginningOfSentenceBoolean = JNI_FALSE; + env->GetBooleanArrayRegion(isBeginningOfSentenceArray, i, 1 /* len */, + &isBeginningOfSentenceBoolean); + isBeginningOfSentence[i] = isBeginningOfSentenceBoolean == JNI_TRUE; + } + return PrevWordsInfo(prevWordCodePoints, prevWordCodePointCount, isBeginningOfSentence, + MAX_PREV_WORD_COUNT_FOR_N_GRAM); + } + static void putBooleanToArray(JNIEnv *env, jbooleanArray array, const int index, const jboolean value) { env->SetBooleanArrayRegion(array, index, 1 /* len */, &value);