Merge "Add methods to inverse compute the probability."
This commit is contained in:
commit
554b85845c
5 changed files with 48 additions and 34 deletions
|
@ -66,6 +66,7 @@ class BinaryFormat {
|
||||||
const int length);
|
const int length);
|
||||||
static int getWordAtAddress(const uint8_t* const root, const int address, const int maxDepth,
|
static int getWordAtAddress(const uint8_t* const root, const int address, const int maxDepth,
|
||||||
uint16_t* outWord);
|
uint16_t* outWord);
|
||||||
|
static int getProbability(const int bigramListPosition, const int unigramFreq);
|
||||||
|
|
||||||
// Flags for special processing
|
// Flags for special processing
|
||||||
// Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or
|
// Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or
|
||||||
|
@ -517,6 +518,14 @@ inline int BinaryFormat::getWordAtAddress(const uint8_t* const root, const int a
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the value to use as a word's probability, given a bigram list position and the
// unigram frequency read from the word's terminal node.
// Currently a placeholder: bigramListPosition is not consulted and unigramFreq is returned
// unchanged (see the TODOs in the body).
// NOTE(review): bigramListPosition is presumably an offset into the dictionary's bigram
// data - confirm once the lookup is implemented.
// This should probably return a probability in log space.
|
||||||
|
inline int BinaryFormat::getProbability(const int bigramListPosition, const int unigramFreq) {
|
||||||
|
// TODO: use the bigram list position to get the bigram probability. If the bigram
|
||||||
|
// is not found, use the unigram frequency.
|
||||||
|
// TODO: if the unigram frequency is used, compute the actual probability
|
||||||
|
// Placeholder behavior: the raw unigram frequency is passed through as-is.
return unigramFreq;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace latinime
|
} // namespace latinime
|
||||||
|
|
||||||
#endif // LATINIME_BINARY_FORMAT_H
|
#endif // LATINIME_BINARY_FORMAT_H
|
||||||
|
|
|
@ -165,28 +165,28 @@ int Correction::getFreqForSplitMultipleWords(const int *freqArray, const int *wo
|
||||||
wordCount, this, isSpaceProximity, word);
|
wordCount, this, isSpaceProximity, word);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLength) {
|
int Correction::getFinalProbability(const int probability, unsigned short **word, int *wordLength) {
|
||||||
return getFinalFreqInternal(freq, word, wordLength, mInputLength);
|
return getFinalProbabilityInternal(probability, word, wordLength, mInputLength);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Correction::getFinalFreqForSubQueue(const int freq, unsigned short **word, int *wordLength,
|
int Correction::getFinalProbabilityForSubQueue(const int probability, unsigned short **word,
|
||||||
const int inputLength) {
|
int *wordLength, const int inputLength) {
|
||||||
return getFinalFreqInternal(freq, word, wordLength, inputLength);
|
return getFinalProbabilityInternal(probability, word, wordLength, inputLength);
|
||||||
}
|
}
|
||||||
|
|
||||||
int Correction::getFinalFreqInternal(const int freq, unsigned short **word, int *wordLength,
|
int Correction::getFinalProbabilityInternal(const int probability, unsigned short **word,
|
||||||
const int inputLength) {
|
int *wordLength, const int inputLength) {
|
||||||
const int outputIndex = mTerminalOutputIndex;
|
const int outputIndex = mTerminalOutputIndex;
|
||||||
const int inputIndex = mTerminalInputIndex;
|
const int inputIndex = mTerminalInputIndex;
|
||||||
*wordLength = outputIndex + 1;
|
*wordLength = outputIndex + 1;
|
||||||
if (outputIndex < MIN_SUGGEST_DEPTH) {
|
if (outputIndex < MIN_SUGGEST_DEPTH) {
|
||||||
return NOT_A_FREQUENCY;
|
return NOT_A_PROBABILITY;
|
||||||
}
|
}
|
||||||
|
|
||||||
*word = mWord;
|
*word = mWord;
|
||||||
int finalFreq = Correction::RankingAlgorithm::calculateFinalFreq(
|
int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability(
|
||||||
inputIndex, outputIndex, freq, mEditDistanceTable, this, inputLength);
|
inputIndex, outputIndex, probability, mEditDistanceTable, this, inputLength);
|
||||||
return finalFreq;
|
return finalProbability;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Correction::initProcessState(const int outputIndex) {
|
bool Correction::initProcessState(const int outputIndex) {
|
||||||
|
@ -649,8 +649,8 @@ inline static bool isUpperCase(unsigned short c) {
|
||||||
//////////////////////
|
//////////////////////
|
||||||
|
|
||||||
/* static */
|
/* static */
|
||||||
int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int outputIndex,
|
int Correction::RankingAlgorithm::calculateFinalProbability(const int inputIndex,
|
||||||
const int freq, int* editDistanceTable, const Correction* correction,
|
const int outputIndex, const int freq, int* editDistanceTable, const Correction* correction,
|
||||||
const int inputLength) {
|
const int inputLength) {
|
||||||
const int excessivePos = correction->getExcessivePos();
|
const int excessivePos = correction->getExcessivePos();
|
||||||
const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER;
|
const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER;
|
||||||
|
|
|
@ -132,9 +132,9 @@ class Correction {
|
||||||
int getFreqForSplitMultipleWords(
|
int getFreqForSplitMultipleWords(
|
||||||
const int *freqArray, const int *wordLengthArray, const int wordCount,
|
const int *freqArray, const int *wordLengthArray, const int wordCount,
|
||||||
const bool isSpaceProximity, const unsigned short *word);
|
const bool isSpaceProximity, const unsigned short *word);
|
||||||
int getFinalFreq(const int freq, unsigned short **word, int* wordLength);
|
int getFinalProbability(const int probability, unsigned short **word, int* wordLength);
|
||||||
int getFinalFreqForSubQueue(const int freq, unsigned short **word, int* wordLength,
|
int getFinalProbabilityForSubQueue(const int probability, unsigned short **word,
|
||||||
const int inputLength);
|
int* wordLength, const int inputLength);
|
||||||
|
|
||||||
CorrectionType processCharAndCalcState(const int32_t c, const bool isTerminal);
|
CorrectionType processCharAndCalcState(const int32_t c, const bool isTerminal);
|
||||||
|
|
||||||
|
@ -156,8 +156,8 @@ class Correction {
|
||||||
|
|
||||||
class RankingAlgorithm {
|
class RankingAlgorithm {
|
||||||
public:
|
public:
|
||||||
static int calculateFinalFreq(const int inputIndex, const int depth,
|
static int calculateFinalProbability(const int inputIndex, const int depth,
|
||||||
const int freq, int *editDistanceTable, const Correction* correction,
|
const int probability, int *editDistanceTable, const Correction* correction,
|
||||||
const int inputLength);
|
const int inputLength);
|
||||||
static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
|
static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
|
||||||
const int wordCount, const Correction* correction, const bool isSpaceProximity,
|
const int wordCount, const Correction* correction, const bool isSpaceProximity,
|
||||||
|
@ -182,8 +182,8 @@ class Correction {
|
||||||
const int32_t c, const bool isTerminal, const bool inputIndexIncremented);
|
const int32_t c, const bool isTerminal, const bool inputIndexIncremented);
|
||||||
inline CorrectionType processUnrelatedCorrectionType();
|
inline CorrectionType processUnrelatedCorrectionType();
|
||||||
inline void addCharToCurrentWord(const int32_t c);
|
inline void addCharToCurrentWord(const int32_t c);
|
||||||
inline int getFinalFreqInternal(const int freq, unsigned short **word, int* wordLength,
|
inline int getFinalProbabilityInternal(const int probability, unsigned short **word,
|
||||||
const int inputLength);
|
int* wordLength, const int inputLength);
|
||||||
|
|
||||||
const int TYPED_LETTER_MULTIPLIER;
|
const int TYPED_LETTER_MULTIPLIER;
|
||||||
const int FULL_WORD_MULTIPLIER;
|
const int FULL_WORD_MULTIPLIER;
|
||||||
|
|
|
@ -172,7 +172,7 @@ static inline void prof_out(void) {
|
||||||
#define PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO -3
|
#define PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO -3
|
||||||
#define ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO -4
|
#define ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO -4
|
||||||
#define NOT_AN_INDEX -1
|
#define NOT_AN_INDEX -1
|
||||||
#define NOT_A_FREQUENCY -1
|
#define NOT_A_PROBABILITY -1
|
||||||
|
|
||||||
#define KEYCODE_SPACE ' '
|
#define KEYCODE_SPACE ' '
|
||||||
|
|
||||||
|
|
|
@ -349,7 +349,7 @@ void UnigramDictionary::getSuggestionCandidates(const bool useFullEditDistance,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void UnigramDictionary::onTerminal(const int freq,
|
inline void UnigramDictionary::onTerminal(const int probability,
|
||||||
const TerminalAttributes& terminalAttributes, Correction *correction,
|
const TerminalAttributes& terminalAttributes, Correction *correction,
|
||||||
WordsPriorityQueuePool *queuePool, const bool addToMasterQueue,
|
WordsPriorityQueuePool *queuePool, const bool addToMasterQueue,
|
||||||
const int currentWordIndex) {
|
const int currentWordIndex) {
|
||||||
|
@ -361,26 +361,28 @@ inline void UnigramDictionary::onTerminal(const int freq,
|
||||||
|
|
||||||
if ((currentWordIndex == FIRST_WORD_INDEX) && addToMasterQueue) {
|
if ((currentWordIndex == FIRST_WORD_INDEX) && addToMasterQueue) {
|
||||||
WordsPriorityQueue *masterQueue = queuePool->getMasterQueue();
|
WordsPriorityQueue *masterQueue = queuePool->getMasterQueue();
|
||||||
const int finalFreq = correction->getFinalFreq(freq, &wordPointer, &wordLength);
|
const int finalProbability =
|
||||||
if (finalFreq != NOT_A_FREQUENCY) {
|
correction->getFinalProbability(probability, &wordPointer, &wordLength);
|
||||||
addWord(wordPointer, wordLength, finalFreq, masterQueue);
|
if (finalProbability != NOT_A_PROBABILITY) {
|
||||||
|
addWord(wordPointer, wordLength, finalProbability, masterQueue);
|
||||||
|
|
||||||
const int shortcutFreq = finalFreq > 0 ? finalFreq - 1 : 0;
|
const int shortcutProbability = finalProbability > 0 ? finalProbability - 1 : 0;
|
||||||
// Please note that the shortcut candidates will be added to the master queue only.
|
// Please note that the shortcut candidates will be added to the master queue only.
|
||||||
TerminalAttributes::ShortcutIterator iterator =
|
TerminalAttributes::ShortcutIterator iterator =
|
||||||
terminalAttributes.getShortcutIterator();
|
terminalAttributes.getShortcutIterator();
|
||||||
while (iterator.hasNextShortcutTarget()) {
|
while (iterator.hasNextShortcutTarget()) {
|
||||||
// TODO: addWord only supports weak ordering, meaning we have no means
|
// TODO: addWord only supports weak ordering, meaning we have no means
|
||||||
// to control the order of the shortcuts relative to one another or to the word.
|
// to control the order of the shortcuts relative to one another or to the word.
|
||||||
// We need to either modulate the frequency of each shortcut according
|
// We need to either modulate the probability of each shortcut according
|
||||||
// to its own shortcut frequency or to make the queue
|
// to its own shortcut probability or to make the queue
|
||||||
// so that the insert order is protected inside the queue for words
|
// so that the insert order is protected inside the queue for words
|
||||||
// with the same score. For the moment we use -1 to make sure the shortcut will
|
// with the same score. For the moment we use -1 to make sure the shortcut will
|
||||||
// never be in front of the word.
|
// never be in front of the word.
|
||||||
uint16_t shortcutTarget[MAX_WORD_LENGTH_INTERNAL];
|
uint16_t shortcutTarget[MAX_WORD_LENGTH_INTERNAL];
|
||||||
const int shortcutTargetStringLength = iterator.getNextShortcutTarget(
|
const int shortcutTargetStringLength = iterator.getNextShortcutTarget(
|
||||||
MAX_WORD_LENGTH_INTERNAL, shortcutTarget);
|
MAX_WORD_LENGTH_INTERNAL, shortcutTarget);
|
||||||
addWord(shortcutTarget, shortcutTargetStringLength, shortcutFreq, masterQueue);
|
addWord(shortcutTarget, shortcutTargetStringLength, shortcutProbability,
|
||||||
|
masterQueue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -393,9 +395,9 @@ inline void UnigramDictionary::onTerminal(const int freq,
|
||||||
if (!subQueue) {
|
if (!subQueue) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const int finalFreq = correction->getFinalFreqForSubQueue(freq, &wordPointer, &wordLength,
|
const int finalProbability = correction->getFinalProbabilityForSubQueue(
|
||||||
inputIndex);
|
probability, &wordPointer, &wordLength, inputIndex);
|
||||||
addWord(wordPointer, wordLength, finalFreq, subQueue);
|
addWord(wordPointer, wordLength, finalProbability, subQueue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -762,6 +764,8 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos,
|
||||||
correction->checkState();
|
correction->checkState();
|
||||||
}
|
}
|
||||||
int pos = initialPos;
|
int pos = initialPos;
|
||||||
|
// TODO: get this as an argument
|
||||||
|
const int bigramListPosition = 0;
|
||||||
|
|
||||||
// Flags contain the following information:
|
// Flags contain the following information:
|
||||||
// - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits:
|
// - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits:
|
||||||
|
@ -834,11 +838,12 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos,
|
||||||
if (isTerminalNode) {
|
if (isTerminalNode) {
|
||||||
// The frequency should be here, because we come here only if this is actually
|
// The frequency should be here, because we come here only if this is actually
|
||||||
// a terminal node, and we are on its last char.
|
// a terminal node, and we are on its last char.
|
||||||
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
|
const int unigramFreq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
|
||||||
const int childrenAddressPos = BinaryFormat::skipFrequency(flags, pos);
|
const int childrenAddressPos = BinaryFormat::skipFrequency(flags, pos);
|
||||||
const int attributesPos = BinaryFormat::skipChildrenPosition(flags, childrenAddressPos);
|
const int attributesPos = BinaryFormat::skipChildrenPosition(flags, childrenAddressPos);
|
||||||
TerminalAttributes terminalAttributes(DICT_ROOT, flags, attributesPos);
|
TerminalAttributes terminalAttributes(DICT_ROOT, flags, attributesPos);
|
||||||
onTerminal(freq, terminalAttributes, correction, queuePool, needsToInvokeOnTerminal,
|
const int probability = BinaryFormat::getProbability(bigramListPosition, unigramFreq);
|
||||||
|
onTerminal(probability, terminalAttributes, correction, queuePool, needsToInvokeOnTerminal,
|
||||||
currentWordIndex);
|
currentWordIndex);
|
||||||
|
|
||||||
// If there are more chars in this node, then this virtual node has children.
|
// If there are more chars in this node, then this virtual node has children.
|
||||||
|
|
Loading…
Reference in a new issue