diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h index b8ac95250..d5d67c108 100644 --- a/native/jni/src/binary_format.h +++ b/native/jni/src/binary_format.h @@ -66,6 +66,7 @@ class BinaryFormat { const int length); static int getWordAtAddress(const uint8_t* const root, const int address, const int maxDepth, uint16_t* outWord); + static int getProbability(const int bigramListPosition, const int unigramFreq); // Flags for special processing // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or @@ -517,6 +518,14 @@ inline int BinaryFormat::getWordAtAddress(const uint8_t* const root, const int a return 0; } +// This should probably return a probability in log space. +inline int BinaryFormat::getProbability(const int bigramListPosition, const int unigramFreq) { + // TODO: use the bigram list position to get the bigram probability. If the bigram + // is not found, use the unigram frequency. + // TODO: if the unigram frequency is used, compute the actual probability + return unigramFreq; +} + } // namespace latinime #endif // LATINIME_BINARY_FORMAT_H diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp index 087219ed4..376e9a10e 100644 --- a/native/jni/src/correction.cpp +++ b/native/jni/src/correction.cpp @@ -165,28 +165,28 @@ int Correction::getFreqForSplitMultipleWords(const int *freqArray, const int *wo wordCount, this, isSpaceProximity, word); } -int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLength) { - return getFinalFreqInternal(freq, word, wordLength, mInputLength); +int Correction::getFinalProbability(const int probability, unsigned short **word, int *wordLength) { + return getFinalProbabilityInternal(probability, word, wordLength, mInputLength); } -int Correction::getFinalFreqForSubQueue(const int freq, unsigned short **word, int *wordLength, - const int inputLength) { - return getFinalFreqInternal(freq, word, wordLength, inputLength); +int Correction::getFinalProbabilityForSubQueue(const int probability, unsigned short **word, + int *wordLength, const int inputLength) { + return getFinalProbabilityInternal(probability, word, wordLength, inputLength); } -int Correction::getFinalFreqInternal(const int freq, unsigned short **word, int *wordLength, - const int inputLength) { +int Correction::getFinalProbabilityInternal(const int probability, unsigned short **word, + int *wordLength, const int inputLength) { const int outputIndex = mTerminalOutputIndex; const int inputIndex = mTerminalInputIndex; *wordLength = outputIndex + 1; if (outputIndex < MIN_SUGGEST_DEPTH) { - return NOT_A_FREQUENCY; + return NOT_A_PROBABILITY; } *word = mWord; - int finalFreq = Correction::RankingAlgorithm::calculateFinalFreq( - inputIndex, outputIndex, freq, mEditDistanceTable, this, inputLength); - return finalFreq; + int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability( + inputIndex, outputIndex, probability, mEditDistanceTable, this, inputLength); + return finalProbability; } bool Correction::initProcessState(const int outputIndex) { @@ -649,8 +649,8 @@ inline static bool isUpperCase(unsigned short c) { ////////////////////// /* static */ -int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int outputIndex, - const int freq, int* editDistanceTable, const Correction* correction, +int Correction::RankingAlgorithm::calculateFinalProbability(const int inputIndex, + const int outputIndex, const int freq, int* editDistanceTable, const Correction* correction, const int inputLength) { const int excessivePos = correction->getExcessivePos(); const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER; diff --git a/native/jni/src/correction.h b/native/jni/src/correction.h index ee55c9604..1b4e4bf4e 100644 --- a/native/jni/src/correction.h +++ b/native/jni/src/correction.h @@ -132,9 +132,9 @@ class Correction { int getFreqForSplitMultipleWords( const int *freqArray, const int *wordLengthArray, const int wordCount, const bool isSpaceProximity, const unsigned short *word); - int getFinalFreq(const int freq, unsigned short **word, int* wordLength); - int getFinalFreqForSubQueue(const int freq, unsigned short **word, int* wordLength, - const int inputLength); + int getFinalProbability(const int probability, unsigned short **word, int* wordLength); + int getFinalProbabilityForSubQueue(const int probability, unsigned short **word, + int* wordLength, const int inputLength); CorrectionType processCharAndCalcState(const int32_t c, const bool isTerminal); @@ -156,8 +156,8 @@ class Correction { class RankingAlgorithm { public: - static int calculateFinalFreq(const int inputIndex, const int depth, - const int freq, int *editDistanceTable, const Correction* correction, + static int calculateFinalProbability(const int inputIndex, const int depth, + const int probability, int *editDistanceTable, const Correction* correction, const int inputLength); static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, const int wordCount, const Correction* correction, const bool isSpaceProximity, @@ -182,8 +182,8 @@ class Correction { const int32_t c, const bool isTerminal, const bool inputIndexIncremented); inline CorrectionType processUnrelatedCorrectionType(); inline void addCharToCurrentWord(const int32_t c); - inline int getFinalFreqInternal(const int freq, unsigned short **word, int* wordLength, - const int inputLength); + inline int getFinalProbabilityInternal(const int probability, unsigned short **word, + int* wordLength, const int inputLength); const int TYPED_LETTER_MULTIPLIER; const int FULL_WORD_MULTIPLIER; diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index e882c3714..c99f8a8b2 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -172,7 +172,7 @@ static inline void prof_out(void) { #define PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO -3 #define ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO -4 #define NOT_AN_INDEX -1 -#define NOT_A_FREQUENCY -1 +#define NOT_A_PROBABILITY -1 #define KEYCODE_SPACE ' ' diff --git a/native/jni/src/unigram_dictionary.cpp b/native/jni/src/unigram_dictionary.cpp index 05c124b94..a3eda0061 100644 --- a/native/jni/src/unigram_dictionary.cpp +++ b/native/jni/src/unigram_dictionary.cpp @@ -349,7 +349,7 @@ void UnigramDictionary::getSuggestionCandidates(const bool useFullEditDistance, } } -inline void UnigramDictionary::onTerminal(const int freq, +inline void UnigramDictionary::onTerminal(const int probability, const TerminalAttributes& terminalAttributes, Correction *correction, WordsPriorityQueuePool *queuePool, const bool addToMasterQueue, const int currentWordIndex) { @@ -361,26 +361,28 @@ inline void UnigramDictionary::onTerminal(const int freq, if ((currentWordIndex == FIRST_WORD_INDEX) && addToMasterQueue) { WordsPriorityQueue *masterQueue = queuePool->getMasterQueue(); - const int finalFreq = correction->getFinalFreq(freq, &wordPointer, &wordLength); - if (finalFreq != NOT_A_FREQUENCY) { - addWord(wordPointer, wordLength, finalFreq, masterQueue); + const int finalProbability = + correction->getFinalProbability(probability, &wordPointer, &wordLength); + if (finalProbability != NOT_A_PROBABILITY) { + addWord(wordPointer, wordLength, finalProbability, masterQueue); - const int shortcutFreq = finalFreq > 0 ? finalFreq - 1 : 0; + const int shortcutProbability = finalProbability > 0 ? finalProbability - 1 : 0; // Please note that the shortcut candidates will be added to the master queue only. TerminalAttributes::ShortcutIterator iterator = terminalAttributes.getShortcutIterator(); while (iterator.hasNextShortcutTarget()) { // TODO: addWord only supports weak ordering, meaning we have no means // to control the order of the shortcuts relative to one another or to the word. - // We need to either modulate the frequency of each shortcut according - // to its own shortcut frequency or to make the queue + // We need to either modulate the probability of each shortcut according + // to its own shortcut probability or to make the queue // so that the insert order is protected inside the queue for words // with the same score. For the moment we use -1 to make sure the shortcut will // never be in front of the word. uint16_t shortcutTarget[MAX_WORD_LENGTH_INTERNAL]; const int shortcutTargetStringLength = iterator.getNextShortcutTarget( MAX_WORD_LENGTH_INTERNAL, shortcutTarget); - addWord(shortcutTarget, shortcutTargetStringLength, shortcutFreq, masterQueue); + addWord(shortcutTarget, shortcutTargetStringLength, shortcutProbability, + masterQueue); } } } @@ -393,9 +395,9 @@ inline void UnigramDictionary::onTerminal(const int freq, if (!subQueue) { return; } - const int finalFreq = correction->getFinalFreqForSubQueue(freq, &wordPointer, &wordLength, - inputIndex); - addWord(wordPointer, wordLength, finalFreq, subQueue); + const int finalProbability = correction->getFinalProbabilityForSubQueue( + probability, &wordPointer, &wordLength, inputIndex); + addWord(wordPointer, wordLength, finalProbability, subQueue); } } @@ -762,6 +764,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, correction->checkState(); } int pos = initialPos; + // TODO: get this as an argument + const int bigramListPosition = 0; // Flags contain the following information: // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits: @@ -834,11 +838,12 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, if (isTerminalNode) { // The frequency should be here, because we come here only if this is actually // a terminal node, and we are on its last char. - const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); + const int unigramFreq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); const int childrenAddressPos = BinaryFormat::skipFrequency(flags, pos); const int attributesPos = BinaryFormat::skipChildrenPosition(flags, childrenAddressPos); TerminalAttributes terminalAttributes(DICT_ROOT, flags, attributesPos); - onTerminal(freq, terminalAttributes, correction, queuePool, needsToInvokeOnTerminal, + const int probability = BinaryFormat::getProbability(bigramListPosition, unigramFreq); + onTerminal(probability, terminalAttributes, correction, queuePool, needsToInvokeOnTerminal, currentWordIndex); // If there are more chars in this node, then this virtual node has children.