Be careful about the dictionary size in detection methods
Bug: 8857618 Change-Id: I29345ec96d53da601571ba73197a6485643a10a7main
parent
1eb1af75a7
commit
03f8c6aed3
|
@ -109,7 +109,8 @@ static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring s
|
||||||
}
|
}
|
||||||
Dictionary *dictionary = 0;
|
Dictionary *dictionary = 0;
|
||||||
if (BinaryFormat::UNKNOWN_FORMAT
|
if (BinaryFormat::UNKNOWN_FORMAT
|
||||||
== BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf))) {
|
== BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf),
|
||||||
|
static_cast<int>(dictSize))) {
|
||||||
AKLOGE("DICT: dictionary format is unknown, bad magic number");
|
AKLOGE("DICT: dictionary format is unknown, bad magic number");
|
||||||
#ifdef USE_MMAP_FOR_DICTIONARY
|
#ifdef USE_MMAP_FOR_DICTIONARY
|
||||||
releaseDictBuf(static_cast<const char *>(dictBuf) - adjust, adjDictSize, fd);
|
releaseDictBuf(static_cast<const char *>(dictBuf) - adjust, adjDictSize, fd);
|
||||||
|
|
|
@ -64,13 +64,14 @@ class BinaryFormat {
|
||||||
static const int UNKNOWN_FORMAT = -1;
|
static const int UNKNOWN_FORMAT = -1;
|
||||||
static const int SHORTCUT_LIST_SIZE_SIZE = 2;
|
static const int SHORTCUT_LIST_SIZE_SIZE = 2;
|
||||||
|
|
||||||
static int detectFormat(const uint8_t *const dict);
|
static int detectFormat(const uint8_t *const dict, const int dictSize);
|
||||||
static int getHeaderSize(const uint8_t *const dict);
|
static int getHeaderSize(const uint8_t *const dict, const int dictSize);
|
||||||
static int getFlags(const uint8_t *const dict);
|
static int getFlags(const uint8_t *const dict, const int dictSize);
|
||||||
static bool hasBlacklistedOrNotAWordFlag(const int flags);
|
static bool hasBlacklistedOrNotAWordFlag(const int flags);
|
||||||
static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue,
|
static void readHeaderValue(const uint8_t *const dict, const int dictSize,
|
||||||
const int outValueSize);
|
const char *const key, int *outValue, const int outValueSize);
|
||||||
static int readHeaderValueInt(const uint8_t *const dict, const char *const key);
|
static int readHeaderValueInt(const uint8_t *const dict, const int dictSize,
|
||||||
|
const char *const key);
|
||||||
static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos);
|
static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos);
|
||||||
static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos);
|
static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos);
|
||||||
static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos);
|
static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos);
|
||||||
|
@ -96,7 +97,7 @@ class BinaryFormat {
|
||||||
const uint8_t *bigramFilter, const int unigramProbability);
|
const uint8_t *bigramFilter, const int unigramProbability);
|
||||||
static int getBigramProbabilityFromHashMap(const int position,
|
static int getBigramProbabilityFromHashMap(const int position,
|
||||||
const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
|
const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
|
||||||
static float getMultiWordCostMultiplier(const uint8_t *const dict);
|
static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize);
|
||||||
static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
|
static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
|
||||||
hash_map_compat<int, int> *bigramMap);
|
hash_map_compat<int, int> *bigramMap);
|
||||||
static int getBigramProbability(const uint8_t *const root, int position,
|
static int getBigramProbability(const uint8_t *const root, int position,
|
||||||
|
@ -122,6 +123,8 @@ class BinaryFormat {
|
||||||
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20;
|
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20;
|
||||||
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30;
|
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30;
|
||||||
|
|
||||||
|
// Any file smaller than this is not a dictionary.
|
||||||
|
static const int DICTIONARY_MINIMUM_SIZE = 4;
|
||||||
// Originally, format version 1 had a 16-bit magic number, then the version number `01'
|
// Originally, format version 1 had a 16-bit magic number, then the version number `01'
|
||||||
// then options that must be 0. Hence the first 32-bits of the format are always as follow
|
// then options that must be 0. Hence the first 32-bits of the format are always as follow
|
||||||
// and it's okay to consider them a magic number as a whole.
|
// and it's okay to consider them a magic number as a whole.
|
||||||
|
@ -131,6 +134,8 @@ class BinaryFormat {
|
||||||
// number, so we had to change it so that version 2 files would be rejected by older
|
// number, so we had to change it so that version 2 files would be rejected by older
|
||||||
// implementations. On this occasion, we made the magic number 32 bits long.
|
// implementations. On this occasion, we made the magic number 32 bits long.
|
||||||
static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE
|
static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE
|
||||||
|
// Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12
|
||||||
|
static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12;
|
||||||
|
|
||||||
static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1;
|
static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1;
|
||||||
static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20;
|
static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20;
|
||||||
|
@ -141,8 +146,11 @@ class BinaryFormat {
|
||||||
static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos);
|
static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos);
|
||||||
};
|
};
|
||||||
|
|
||||||
AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
|
AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) {
|
||||||
// The magic number is stored big-endian.
|
// The magic number is stored big-endian.
|
||||||
|
// If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't
|
||||||
|
// understand this format.
|
||||||
|
if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT;
|
||||||
const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3];
|
const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3];
|
||||||
switch (magicNumber) {
|
switch (magicNumber) {
|
||||||
case FORMAT_VERSION_1_MAGIC_NUMBER:
|
case FORMAT_VERSION_1_MAGIC_NUMBER:
|
||||||
|
@ -152,6 +160,10 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
|
||||||
// Options (2 bytes) must be 0x00 0x00
|
// Options (2 bytes) must be 0x00 0x00
|
||||||
return 1;
|
return 1;
|
||||||
case FORMAT_VERSION_2_MAGIC_NUMBER:
|
case FORMAT_VERSION_2_MAGIC_NUMBER:
|
||||||
|
// Version 2 dictionaries are at least 12 bytes long (see below details for the header).
|
||||||
|
// If this dictionary has the version 2 magic number but is less than 12 bytes long, then
|
||||||
|
// it's an unknown format and we need to avoid confidently reading the next bytes.
|
||||||
|
if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT;
|
||||||
// Format 2 header is as follows:
|
// Format 2 header is as follows:
|
||||||
// Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE
|
// Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE
|
||||||
// Version number (2 bytes) 0x00 0x02
|
// Version number (2 bytes) 0x00 0x02
|
||||||
|
@ -163,8 +175,8 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int BinaryFormat::getFlags(const uint8_t *const dict) {
|
inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) {
|
||||||
switch (detectFormat(dict)) {
|
switch (detectFormat(dict, dictSize)) {
|
||||||
case 1:
|
case 1:
|
||||||
return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else?
|
return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else?
|
||||||
default:
|
default:
|
||||||
|
@ -176,8 +188,8 @@ inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) {
|
||||||
return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
|
return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
|
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) {
|
||||||
switch (detectFormat(dict)) {
|
switch (detectFormat(dict, dictSize)) {
|
||||||
case 1:
|
case 1:
|
||||||
return FORMAT_VERSION_1_HEADER_SIZE;
|
return FORMAT_VERSION_1_HEADER_SIZE;
|
||||||
case 2:
|
case 2:
|
||||||
|
@ -188,12 +200,12 @@ inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key,
|
inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize,
|
||||||
int *outValue, const int outValueSize) {
|
const char *const key, int *outValue, const int outValueSize) {
|
||||||
int outValueIndex = 0;
|
int outValueIndex = 0;
|
||||||
// Only format 2 and above have header attributes as {key,value} string pairs. For prior
|
// Only format 2 and above have header attributes as {key,value} string pairs. For prior
|
||||||
// formats, we just return an empty string, as if the key wasn't found.
|
// formats, we just return an empty string, as if the key wasn't found.
|
||||||
if (2 <= detectFormat(dict)) {
|
if (2 <= detectFormat(dict, dictSize)) {
|
||||||
const int headerOptionsOffset = 4 /* magic number */
|
const int headerOptionsOffset = 4 /* magic number */
|
||||||
+ 2 /* dictionary version */ + 2 /* flags */;
|
+ 2 /* dictionary version */ + 2 /* flags */;
|
||||||
const int headerSize =
|
const int headerSize =
|
||||||
|
@ -236,11 +248,12 @@ inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char
|
||||||
if (outValueIndex >= 0) outValue[outValueIndex] = 0;
|
if (outValueIndex >= 0) outValue[outValueIndex] = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) {
|
inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize,
|
||||||
|
const char *const key) {
|
||||||
const int bufferSize = LARGEST_INT_DIGIT_COUNT;
|
const int bufferSize = LARGEST_INT_DIGIT_COUNT;
|
||||||
int intBuffer[bufferSize];
|
int intBuffer[bufferSize];
|
||||||
char charBuffer[bufferSize];
|
char charBuffer[bufferSize];
|
||||||
BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize);
|
BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize);
|
||||||
for (int i = 0; i < bufferSize; ++i) {
|
for (int i = 0; i < bufferSize; ++i) {
|
||||||
charBuffer[i] = intBuffer[i];
|
charBuffer[i] = intBuffer[i];
|
||||||
}
|
}
|
||||||
|
@ -256,8 +269,10 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *
|
||||||
return ((msb & 0x7F) << 8) | dict[(*pos)++];
|
return ((msb & 0x7F) << 8) | dict[(*pos)++];
|
||||||
}
|
}
|
||||||
|
|
||||||
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) {
|
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict,
|
||||||
const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE");
|
const int dictSize) {
|
||||||
|
const int headerValue = readHeaderValueInt(dict, dictSize,
|
||||||
|
"MULTIPLE_WORDS_DEMOTION_RATE");
|
||||||
if (headerValue == S_INT_MIN) {
|
if (headerValue == S_INT_MIN) {
|
||||||
return 1.0f;
|
return 1.0f;
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,9 +34,11 @@ namespace latinime {
|
||||||
|
|
||||||
Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust)
|
Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust)
|
||||||
: mDict(static_cast<unsigned char *>(dict)),
|
: mDict(static_cast<unsigned char *>(dict)),
|
||||||
mOffsetDict((static_cast<unsigned char *>(dict)) + BinaryFormat::getHeaderSize(mDict)),
|
mOffsetDict((static_cast<unsigned char *>(dict))
|
||||||
|
+ BinaryFormat::getHeaderSize(mDict, dictSize)),
|
||||||
mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust),
|
mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust),
|
||||||
mUnigramDictionary(new UnigramDictionary(mOffsetDict, BinaryFormat::getFlags(mDict))),
|
mUnigramDictionary(new UnigramDictionary(mOffsetDict,
|
||||||
|
BinaryFormat::getFlags(mDict, dictSize))),
|
||||||
mBigramDictionary(new BigramDictionary(mOffsetDict)),
|
mBigramDictionary(new BigramDictionary(mOffsetDict)),
|
||||||
mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())),
|
mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())),
|
||||||
mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) {
|
mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) {
|
||||||
|
|
|
@ -64,7 +64,8 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer;
|
||||||
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
|
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
|
||||||
int prevWordLength) {
|
int prevWordLength) {
|
||||||
mDictionary = dictionary;
|
mDictionary = dictionary;
|
||||||
mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict());
|
mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(),
|
||||||
|
mDictionary->getDictSize());
|
||||||
if (!prevWord) {
|
if (!prevWord) {
|
||||||
mPrevWordPos = NOT_VALID_WORD;
|
mPrevWordPos = NOT_VALID_WORD;
|
||||||
return;
|
return;
|
||||||
|
|
Loading…
Reference in New Issue