Merge "Be careful about the dictionary size in detection methods"
This commit is contained in:
commit
5064ac8855
4 changed files with 42 additions and 23 deletions
|
@ -109,7 +109,8 @@ static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring s
|
|||
}
|
||||
Dictionary *dictionary = 0;
|
||||
if (BinaryFormat::UNKNOWN_FORMAT
|
||||
== BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf))) {
|
||||
== BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf),
|
||||
static_cast<int>(dictSize))) {
|
||||
AKLOGE("DICT: dictionary format is unknown, bad magic number");
|
||||
#ifdef USE_MMAP_FOR_DICTIONARY
|
||||
releaseDictBuf(static_cast<const char *>(dictBuf) - adjust, adjDictSize, fd);
|
||||
|
|
|
@ -64,13 +64,14 @@ class BinaryFormat {
|
|||
static const int UNKNOWN_FORMAT = -1;
|
||||
static const int SHORTCUT_LIST_SIZE_SIZE = 2;
|
||||
|
||||
static int detectFormat(const uint8_t *const dict);
|
||||
static int getHeaderSize(const uint8_t *const dict);
|
||||
static int getFlags(const uint8_t *const dict);
|
||||
static int detectFormat(const uint8_t *const dict, const int dictSize);
|
||||
static int getHeaderSize(const uint8_t *const dict, const int dictSize);
|
||||
static int getFlags(const uint8_t *const dict, const int dictSize);
|
||||
static bool hasBlacklistedOrNotAWordFlag(const int flags);
|
||||
static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue,
|
||||
const int outValueSize);
|
||||
static int readHeaderValueInt(const uint8_t *const dict, const char *const key);
|
||||
static void readHeaderValue(const uint8_t *const dict, const int dictSize,
|
||||
const char *const key, int *outValue, const int outValueSize);
|
||||
static int readHeaderValueInt(const uint8_t *const dict, const int dictSize,
|
||||
const char *const key);
|
||||
static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos);
|
||||
static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos);
|
||||
static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos);
|
||||
|
@ -96,7 +97,7 @@ class BinaryFormat {
|
|||
const uint8_t *bigramFilter, const int unigramProbability);
|
||||
static int getBigramProbabilityFromHashMap(const int position,
|
||||
const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
|
||||
static float getMultiWordCostMultiplier(const uint8_t *const dict);
|
||||
static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize);
|
||||
static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
|
||||
hash_map_compat<int, int> *bigramMap);
|
||||
static int getBigramProbability(const uint8_t *const root, int position,
|
||||
|
@ -122,6 +123,8 @@ class BinaryFormat {
|
|||
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20;
|
||||
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30;
|
||||
|
||||
// Any file smaller than this is not a dictionary.
|
||||
static const int DICTIONARY_MINIMUM_SIZE = 4;
|
||||
// Originally, format version 1 had a 16-bit magic number, then the version number `01'
|
||||
// then options that must be 0. Hence the first 32-bits of the format are always as follow
|
||||
// and it's okay to consider them a magic number as a whole.
|
||||
|
@ -131,6 +134,8 @@ class BinaryFormat {
|
|||
// number, so we had to change it so that version 2 files would be rejected by older
|
||||
// implementations. On this occasion, we made the magic number 32 bits long.
|
||||
static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE
|
||||
// Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12
|
||||
static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12;
|
||||
|
||||
static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1;
|
||||
static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20;
|
||||
|
@ -141,8 +146,11 @@ class BinaryFormat {
|
|||
static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos);
|
||||
};
|
||||
|
||||
AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
|
||||
AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) {
|
||||
// The magic number is stored big-endian.
|
||||
// If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't
|
||||
// understand this format.
|
||||
if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT;
|
||||
const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3];
|
||||
switch (magicNumber) {
|
||||
case FORMAT_VERSION_1_MAGIC_NUMBER:
|
||||
|
@ -152,6 +160,10 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
|
|||
// Options (2 bytes) must be 0x00 0x00
|
||||
return 1;
|
||||
case FORMAT_VERSION_2_MAGIC_NUMBER:
|
||||
// Version 2 dictionaries are at least 12 bytes long (see below details for the header).
|
||||
// If this dictionary has the version 2 magic number but is less than 12 bytes long, then
|
||||
// it's an unknown format and we need to avoid confidently reading the next bytes.
|
||||
if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT;
|
||||
// Format 2 header is as follows:
|
||||
// Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE
|
||||
// Version number (2 bytes) 0x00 0x02
|
||||
|
@ -163,8 +175,8 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
|
|||
}
|
||||
}
|
||||
|
||||
inline int BinaryFormat::getFlags(const uint8_t *const dict) {
|
||||
switch (detectFormat(dict)) {
|
||||
inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) {
|
||||
switch (detectFormat(dict, dictSize)) {
|
||||
case 1:
|
||||
return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else?
|
||||
default:
|
||||
|
@ -176,8 +188,8 @@ inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) {
|
|||
return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
|
||||
}
|
||||
|
||||
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
|
||||
switch (detectFormat(dict)) {
|
||||
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) {
|
||||
switch (detectFormat(dict, dictSize)) {
|
||||
case 1:
|
||||
return FORMAT_VERSION_1_HEADER_SIZE;
|
||||
case 2:
|
||||
|
@ -188,12 +200,12 @@ inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
|
|||
}
|
||||
}
|
||||
|
||||
inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key,
|
||||
int *outValue, const int outValueSize) {
|
||||
inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize,
|
||||
const char *const key, int *outValue, const int outValueSize) {
|
||||
int outValueIndex = 0;
|
||||
// Only format 2 and above have header attributes as {key,value} string pairs. For prior
|
||||
// formats, we just return an empty string, as if the key wasn't found.
|
||||
if (2 <= detectFormat(dict)) {
|
||||
if (2 <= detectFormat(dict, dictSize)) {
|
||||
const int headerOptionsOffset = 4 /* magic number */
|
||||
+ 2 /* dictionary version */ + 2 /* flags */;
|
||||
const int headerSize =
|
||||
|
@ -236,11 +248,12 @@ inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char
|
|||
if (outValueIndex >= 0) outValue[outValueIndex] = 0;
|
||||
}
|
||||
|
||||
inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) {
|
||||
inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize,
|
||||
const char *const key) {
|
||||
const int bufferSize = LARGEST_INT_DIGIT_COUNT;
|
||||
int intBuffer[bufferSize];
|
||||
char charBuffer[bufferSize];
|
||||
BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize);
|
||||
BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize);
|
||||
for (int i = 0; i < bufferSize; ++i) {
|
||||
charBuffer[i] = intBuffer[i];
|
||||
}
|
||||
|
@ -256,8 +269,10 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *
|
|||
return ((msb & 0x7F) << 8) | dict[(*pos)++];
|
||||
}
|
||||
|
||||
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) {
|
||||
const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE");
|
||||
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict,
|
||||
const int dictSize) {
|
||||
const int headerValue = readHeaderValueInt(dict, dictSize,
|
||||
"MULTIPLE_WORDS_DEMOTION_RATE");
|
||||
if (headerValue == S_INT_MIN) {
|
||||
return 1.0f;
|
||||
}
|
||||
|
|
|
@ -34,9 +34,11 @@ namespace latinime {
|
|||
|
||||
Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust)
|
||||
: mDict(static_cast<unsigned char *>(dict)),
|
||||
mOffsetDict((static_cast<unsigned char *>(dict)) + BinaryFormat::getHeaderSize(mDict)),
|
||||
mOffsetDict((static_cast<unsigned char *>(dict))
|
||||
+ BinaryFormat::getHeaderSize(mDict, dictSize)),
|
||||
mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust),
|
||||
mUnigramDictionary(new UnigramDictionary(mOffsetDict, BinaryFormat::getFlags(mDict))),
|
||||
mUnigramDictionary(new UnigramDictionary(mOffsetDict,
|
||||
BinaryFormat::getFlags(mDict, dictSize))),
|
||||
mBigramDictionary(new BigramDictionary(mOffsetDict)),
|
||||
mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())),
|
||||
mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) {
|
||||
|
|
|
@ -64,7 +64,8 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer;
|
|||
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
|
||||
int prevWordLength) {
|
||||
mDictionary = dictionary;
|
||||
mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict());
|
||||
mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(),
|
||||
mDictionary->getDictSize());
|
||||
if (!prevWord) {
|
||||
mPrevWordPos = NOT_VALID_WORD;
|
||||
return;
|
||||
|
|
Loading…
Reference in a new issue