Be careful about the dictionary size in detection methods

Bug: 8857618
Change-Id: I29345ec96d53da601571ba73197a6485643a10a7
main
Jean Chalard 2013-05-08 15:24:20 +09:00
parent 1eb1af75a7
commit 03f8c6aed3
4 changed files with 42 additions and 23 deletions

View File

@ -109,7 +109,8 @@ static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring s
} }
Dictionary *dictionary = 0; Dictionary *dictionary = 0;
if (BinaryFormat::UNKNOWN_FORMAT if (BinaryFormat::UNKNOWN_FORMAT
== BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf))) { == BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf),
static_cast<int>(dictSize))) {
AKLOGE("DICT: dictionary format is unknown, bad magic number"); AKLOGE("DICT: dictionary format is unknown, bad magic number");
#ifdef USE_MMAP_FOR_DICTIONARY #ifdef USE_MMAP_FOR_DICTIONARY
releaseDictBuf(static_cast<const char *>(dictBuf) - adjust, adjDictSize, fd); releaseDictBuf(static_cast<const char *>(dictBuf) - adjust, adjDictSize, fd);

View File

@ -64,13 +64,14 @@ class BinaryFormat {
static const int UNKNOWN_FORMAT = -1; static const int UNKNOWN_FORMAT = -1;
static const int SHORTCUT_LIST_SIZE_SIZE = 2; static const int SHORTCUT_LIST_SIZE_SIZE = 2;
static int detectFormat(const uint8_t *const dict); static int detectFormat(const uint8_t *const dict, const int dictSize);
static int getHeaderSize(const uint8_t *const dict); static int getHeaderSize(const uint8_t *const dict, const int dictSize);
static int getFlags(const uint8_t *const dict); static int getFlags(const uint8_t *const dict, const int dictSize);
static bool hasBlacklistedOrNotAWordFlag(const int flags); static bool hasBlacklistedOrNotAWordFlag(const int flags);
static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue, static void readHeaderValue(const uint8_t *const dict, const int dictSize,
const int outValueSize); const char *const key, int *outValue, const int outValueSize);
static int readHeaderValueInt(const uint8_t *const dict, const char *const key); static int readHeaderValueInt(const uint8_t *const dict, const int dictSize,
const char *const key);
static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos); static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos);
static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos); static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos);
static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos); static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos);
@ -96,7 +97,7 @@ class BinaryFormat {
const uint8_t *bigramFilter, const int unigramProbability); const uint8_t *bigramFilter, const int unigramProbability);
static int getBigramProbabilityFromHashMap(const int position, static int getBigramProbabilityFromHashMap(const int position,
const hash_map_compat<int, int> *bigramMap, const int unigramProbability); const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
static float getMultiWordCostMultiplier(const uint8_t *const dict); static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize);
static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
hash_map_compat<int, int> *bigramMap); hash_map_compat<int, int> *bigramMap);
static int getBigramProbability(const uint8_t *const root, int position, static int getBigramProbability(const uint8_t *const root, int position,
@ -122,6 +123,8 @@ class BinaryFormat {
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20;
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30;
// Any file smaller than this is not a dictionary.
static const int DICTIONARY_MINIMUM_SIZE = 4;
// Originally, format version 1 had a 16-bit magic number, then the version number `01' // Originally, format version 1 had a 16-bit magic number, then the version number `01'
// then options that must be 0. Hence the first 32-bits of the format are always as follow // then options that must be 0. Hence the first 32-bits of the format are always as follow
// and it's okay to consider them a magic number as a whole. // and it's okay to consider them a magic number as a whole.
@ -131,6 +134,8 @@ class BinaryFormat {
// number, so we had to change it so that version 2 files would be rejected by older // number, so we had to change it so that version 2 files would be rejected by older
// implementations. On this occasion, we made the magic number 32 bits long. // implementations. On this occasion, we made the magic number 32 bits long.
static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE
// Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12
static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12;
static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1; static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1;
static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20;
@ -141,8 +146,11 @@ class BinaryFormat {
static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos); static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos);
}; };
AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) {
// The magic number is stored big-endian. // The magic number is stored big-endian.
// If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't
// understand this format.
if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT;
const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3]; const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3];
switch (magicNumber) { switch (magicNumber) {
case FORMAT_VERSION_1_MAGIC_NUMBER: case FORMAT_VERSION_1_MAGIC_NUMBER:
@ -152,6 +160,10 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
// Options (2 bytes) must be 0x00 0x00 // Options (2 bytes) must be 0x00 0x00
return 1; return 1;
case FORMAT_VERSION_2_MAGIC_NUMBER: case FORMAT_VERSION_2_MAGIC_NUMBER:
// Version 2 dictionaries are at least 12 bytes long (see below details for the header).
// If this dictionary has the version 2 magic number but is less than 12 bytes long, then
// it's an unknown format and we need to avoid confidently reading the next bytes.
if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT;
// Format 2 header is as follows: // Format 2 header is as follows:
// Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE
// Version number (2 bytes) 0x00 0x02 // Version number (2 bytes) 0x00 0x02
@ -163,8 +175,8 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
} }
} }
inline int BinaryFormat::getFlags(const uint8_t *const dict) { inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) {
switch (detectFormat(dict)) { switch (detectFormat(dict, dictSize)) {
case 1: case 1:
return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else? return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else?
default: default:
@ -176,8 +188,8 @@ inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) {
return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0; return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
} }
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) { inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) {
switch (detectFormat(dict)) { switch (detectFormat(dict, dictSize)) {
case 1: case 1:
return FORMAT_VERSION_1_HEADER_SIZE; return FORMAT_VERSION_1_HEADER_SIZE;
case 2: case 2:
@ -188,12 +200,12 @@ inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
} }
} }
inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key, inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize,
int *outValue, const int outValueSize) { const char *const key, int *outValue, const int outValueSize) {
int outValueIndex = 0; int outValueIndex = 0;
// Only format 2 and above have header attributes as {key,value} string pairs. For prior // Only format 2 and above have header attributes as {key,value} string pairs. For prior
// formats, we just return an empty string, as if the key wasn't found. // formats, we just return an empty string, as if the key wasn't found.
if (2 <= detectFormat(dict)) { if (2 <= detectFormat(dict, dictSize)) {
const int headerOptionsOffset = 4 /* magic number */ const int headerOptionsOffset = 4 /* magic number */
+ 2 /* dictionary version */ + 2 /* flags */; + 2 /* dictionary version */ + 2 /* flags */;
const int headerSize = const int headerSize =
@ -236,11 +248,12 @@ inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char
if (outValueIndex >= 0) outValue[outValueIndex] = 0; if (outValueIndex >= 0) outValue[outValueIndex] = 0;
} }
inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) { inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize,
const char *const key) {
const int bufferSize = LARGEST_INT_DIGIT_COUNT; const int bufferSize = LARGEST_INT_DIGIT_COUNT;
int intBuffer[bufferSize]; int intBuffer[bufferSize];
char charBuffer[bufferSize]; char charBuffer[bufferSize];
BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize); BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize);
for (int i = 0; i < bufferSize; ++i) { for (int i = 0; i < bufferSize; ++i) {
charBuffer[i] = intBuffer[i]; charBuffer[i] = intBuffer[i];
} }
@ -256,8 +269,10 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *
return ((msb & 0x7F) << 8) | dict[(*pos)++]; return ((msb & 0x7F) << 8) | dict[(*pos)++];
} }
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) { inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict,
const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE"); const int dictSize) {
const int headerValue = readHeaderValueInt(dict, dictSize,
"MULTIPLE_WORDS_DEMOTION_RATE");
if (headerValue == S_INT_MIN) { if (headerValue == S_INT_MIN) {
return 1.0f; return 1.0f;
} }

View File

@ -34,9 +34,11 @@ namespace latinime {
Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust) Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust)
: mDict(static_cast<unsigned char *>(dict)), : mDict(static_cast<unsigned char *>(dict)),
mOffsetDict((static_cast<unsigned char *>(dict)) + BinaryFormat::getHeaderSize(mDict)), mOffsetDict((static_cast<unsigned char *>(dict))
+ BinaryFormat::getHeaderSize(mDict, dictSize)),
mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust), mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust),
mUnigramDictionary(new UnigramDictionary(mOffsetDict, BinaryFormat::getFlags(mDict))), mUnigramDictionary(new UnigramDictionary(mOffsetDict,
BinaryFormat::getFlags(mDict, dictSize))),
mBigramDictionary(new BigramDictionary(mOffsetDict)), mBigramDictionary(new BigramDictionary(mOffsetDict)),
mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())), mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())),
mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) { mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) {

View File

@ -64,7 +64,8 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer;
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
int prevWordLength) { int prevWordLength) {
mDictionary = dictionary; mDictionary = dictionary;
mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict()); mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(),
mDictionary->getDictSize());
if (!prevWord) { if (!prevWord) {
mPrevWordPos = NOT_VALID_WORD; mPrevWordPos = NOT_VALID_WORD;
return; return;