diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index 11fa3da3a..1dd68ea8b 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -109,7 +109,8 @@ static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring s } Dictionary *dictionary = 0; if (BinaryFormat::UNKNOWN_FORMAT - == BinaryFormat::detectFormat(static_cast(dictBuf))) { + == BinaryFormat::detectFormat(static_cast(dictBuf), + static_cast(dictSize))) { AKLOGE("DICT: dictionary format is unknown, bad magic number"); #ifdef USE_MMAP_FOR_DICTIONARY releaseDictBuf(static_cast(dictBuf) - adjust, adjDictSize, fd); diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h index 06f50dc7f..98241532f 100644 --- a/native/jni/src/binary_format.h +++ b/native/jni/src/binary_format.h @@ -64,13 +64,14 @@ class BinaryFormat { static const int UNKNOWN_FORMAT = -1; static const int SHORTCUT_LIST_SIZE_SIZE = 2; - static int detectFormat(const uint8_t *const dict); - static int getHeaderSize(const uint8_t *const dict); - static int getFlags(const uint8_t *const dict); + static int detectFormat(const uint8_t *const dict, const int dictSize); + static int getHeaderSize(const uint8_t *const dict, const int dictSize); + static int getFlags(const uint8_t *const dict, const int dictSize); static bool hasBlacklistedOrNotAWordFlag(const int flags); - static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue, - const int outValueSize); - static int readHeaderValueInt(const uint8_t *const dict, const char *const key); + static void readHeaderValue(const uint8_t *const dict, const int dictSize, + const char *const key, int *outValue, const int outValueSize); + static int readHeaderValueInt(const uint8_t *const dict, const int dictSize, + const char *const key); static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos); static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos); static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos); @@ -96,7 +97,7 @@ class BinaryFormat { const uint8_t *bigramFilter, const int unigramProbability); static int getBigramProbabilityFromHashMap(const int position, const hash_map_compat *bigramMap, const int unigramProbability); - static float getMultiWordCostMultiplier(const uint8_t *const dict); + static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize); static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, hash_map_compat *bigramMap); static int getBigramProbability(const uint8_t *const root, int position, @@ -122,6 +123,8 @@ class BinaryFormat { static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; + // Any file smaller than this is not a dictionary. + static const int DICTIONARY_MINIMUM_SIZE = 4; // Originally, format version 1 had a 16-bit magic number, then the version number `01' // then options that must be 0. Hence the first 32-bits of the format are always as follow // and it's okay to consider them a magic number as a whole. @@ -131,6 +134,8 @@ class BinaryFormat { // number, so we had to change it so that version 2 files would be rejected by older // implementations. On this occasion, we made the magic number 32 bits long. static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE + // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 + static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12; static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1; static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; @@ -141,8 +146,11 @@ class BinaryFormat { static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos); }; -AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { +AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) { // The magic number is stored big-endian. + // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't + // understand this format. + if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT; const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3]; switch (magicNumber) { case FORMAT_VERSION_1_MAGIC_NUMBER: @@ -152,6 +160,10 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { // Options (2 bytes) must be 0x00 0x00 return 1; case FORMAT_VERSION_2_MAGIC_NUMBER: + // Version 2 dictionaries are at least 12 bytes long (see below details for the header). + // If this dictionary has the version 2 magic number but is less than 12 bytes long, then + // it's an unknown format and we need to avoid confidently reading the next bytes. + if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT; // Format 2 header is as follows: // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE // Version number (2 bytes) 0x00 0x02 @@ -163,8 +175,8 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { } } -inline int BinaryFormat::getFlags(const uint8_t *const dict) { - switch (detectFormat(dict)) { +inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) { + switch (detectFormat(dict, dictSize)) { case 1: return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else? default: @@ -176,8 +188,8 @@ inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) { return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0; } -inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) { - switch (detectFormat(dict)) { +inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) { + switch (detectFormat(dict, dictSize)) { case 1: return FORMAT_VERSION_1_HEADER_SIZE; case 2: @@ -188,12 +200,12 @@ inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) { } } -inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key, - int *outValue, const int outValueSize) { +inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize, + const char *const key, int *outValue, const int outValueSize) { int outValueIndex = 0; // Only format 2 and above have header attributes as {key,value} string pairs. For prior // formats, we just return an empty string, as if the key wasn't found. - if (2 <= detectFormat(dict)) { + if (2 <= detectFormat(dict, dictSize)) { const int headerOptionsOffset = 4 /* magic number */ + 2 /* dictionary version */ + 2 /* flags */; const int headerSize = @@ -236,11 +248,12 @@ inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char if (outValueIndex >= 0) outValue[outValueIndex] = 0; } -inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) { +inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize, + const char *const key) { const int bufferSize = LARGEST_INT_DIGIT_COUNT; int intBuffer[bufferSize]; char charBuffer[bufferSize]; - BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize); + BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize); for (int i = 0; i < bufferSize; ++i) { charBuffer[i] = intBuffer[i]; } @@ -256,8 +269,10 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t * return ((msb & 0x7F) << 8) | dict[(*pos)++]; } -inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) { - const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE"); +inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict, + const int dictSize) { + const int headerValue = readHeaderValueInt(dict, dictSize, + "MULTIPLE_WORDS_DEMOTION_RATE"); if (headerValue == S_INT_MIN) { return 1.0f; } diff --git a/native/jni/src/dictionary.cpp b/native/jni/src/dictionary.cpp index c998c0676..dadb2bab2 100644 --- a/native/jni/src/dictionary.cpp +++ b/native/jni/src/dictionary.cpp @@ -34,9 +34,11 @@ namespace latinime { Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust) : mDict(static_cast(dict)), - mOffsetDict((static_cast(dict)) + BinaryFormat::getHeaderSize(mDict)), + mOffsetDict((static_cast(dict)) + + BinaryFormat::getHeaderSize(mDict, dictSize)), mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust), - mUnigramDictionary(new UnigramDictionary(mOffsetDict, BinaryFormat::getFlags(mDict))), + mUnigramDictionary(new UnigramDictionary(mOffsetDict, + BinaryFormat::getFlags(mDict, dictSize))), mBigramDictionary(new BigramDictionary(mOffsetDict)), mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())), mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) { diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 51165858b..6408f0163 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -64,7 +64,8 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer; void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, int prevWordLength) { mDictionary = dictionary; - mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict()); + mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(), + mDictionary->getDictSize()); if (!prevWord) { mPrevWordPos = NOT_VALID_WORD; return;