From 3107b467c91c471ce4e00c5d8de559f7b0da2cd6 Mon Sep 17 00:00:00 2001 From: Satoshi Kataoka Date: Mon, 1 Apr 2013 17:57:31 +0900 Subject: [PATCH] Move policy and session to AOSP Bug: 8197301 Change-Id: I742ff0d939f9ad1ee2cd8b959b6c5ee2757fd177 --- native/jni/Android.mk | 7 +- .../suggest/core/dicnode/dic_nodes_cache.cpp | 59 +++++ .../suggest/core/dicnode/dic_nodes_cache.h | 185 +++++++++++++ native/jni/src/suggest/core/policy/scoring.h | 57 ++++ .../src/suggest/core/policy/suggest_policy.h | 39 +++ .../jni/src/suggest/core/policy/traversal.h | 61 +++++ .../jni/src/suggest/core/policy/weighting.cpp | 244 ++++++++++++++++++ .../jni/src/suggest/core/policy/weighting.h | 104 ++++++++ .../core/session/dic_traverse_session.cpp | 106 ++++++++ .../core/session/dic_traverse_session.h | 171 ++++++++++++ 10 files changed, 1032 insertions(+), 1 deletion(-) create mode 100644 native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp create mode 100644 native/jni/src/suggest/core/dicnode/dic_nodes_cache.h create mode 100644 native/jni/src/suggest/core/policy/scoring.h create mode 100644 native/jni/src/suggest/core/policy/suggest_policy.h create mode 100644 native/jni/src/suggest/core/policy/traversal.h create mode 100644 native/jni/src/suggest/core/policy/weighting.cpp create mode 100644 native/jni/src/suggest/core/policy/weighting.h create mode 100644 native/jni/src/suggest/core/session/dic_traverse_session.cpp create mode 100644 native/jni/src/suggest/core/session/dic_traverse_session.h diff --git a/native/jni/Android.mk b/native/jni/Android.mk index 12f99eb52..423c24e88 100644 --- a/native/jni/Android.mk +++ b/native/jni/Android.mk @@ -29,7 +29,9 @@ LATIN_IME_SRC_FULLPATH_DIR := $(LOCAL_PATH)/$(LATIN_IME_SRC_DIR) LOCAL_C_INCLUDES += \ $(LATIN_IME_SRC_FULLPATH_DIR) \ $(LATIN_IME_SRC_FULLPATH_DIR)/suggest \ - $(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/dicnode + $(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/dicnode \ + $(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/policy \ + $(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/session LOCAL_CFLAGS += -Werror -Wall -Wextra -Weffc++ -Wformat=2 -Wcast-qual -Wcast-align \ -Wwrite-strings -Wfloat-equal -Wpointer-arith -Winit-self -Wredundant-decls -Wno-system-headers @@ -63,7 +65,10 @@ LATIN_IME_CORE_SRC_FILES := \ unigram_dictionary.cpp \ words_priority_queue.cpp \ suggest/core/dicnode/dic_node.cpp \ + suggest/core/dicnode/dic_nodes_cache.cpp \ suggest/core/dicnode/dic_node_utils.cpp \ + suggest/core/policy/weighting.cpp \ + suggest/core/session/dic_traverse_session.cpp \ suggest/gesture_suggest.cpp \ suggest/typing_suggest.cpp diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp new file mode 100644 index 000000000..b9a60780b --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "defines.h" +#include "dic_node_priority_queue.h" +#include "dic_node_utils.h" +#include "dic_nodes_cache.h" + +namespace latinime { + +/** + * Truncates all of the dicNodes so that they start at the given commit point. + * Only called for multi-word typing input. + */ +DicNode *DicNodesCache::setCommitPoint(int commitPoint) { + std::list dicNodesList; + while (mCachedDicNodesForContinuousSuggestion->getSize() > 0) { + DicNode dicNode; + mCachedDicNodesForContinuousSuggestion->copyPop(&dicNode); + dicNodesList.push_front(dicNode); + } + + // Get the starting words of the top scoring dicNode (last dicNode popped from priority queue) + // up to the commit point. These words have already been committed to the text view. + DicNode *topDicNode = &dicNodesList.front(); + DicNode topDicNodeCopy; + DicNodeUtils::initByCopy(topDicNode, &topDicNodeCopy); + + // Keep only those dicNodes that match the same starting words. + std::list::iterator iter; + for (iter = dicNodesList.begin(); iter != dicNodesList.end(); iter++) { + DicNode *dicNode = &*iter; + if (dicNode->truncateNode(&topDicNodeCopy, commitPoint)) { + mCachedDicNodesForContinuousSuggestion->copyPush(dicNode); + } else { + // Top dicNode should be reprocessed. + ASSERT(dicNode != topDicNode); + DicNode::managedDelete(dicNode); + } + } + mInputIndex -= commitPoint; + return topDicNode; +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h new file mode 100644 index 000000000..a62aa422a --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h @@ -0,0 +1,185 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODES_CACHE_H +#define LATINIME_DIC_NODES_CACHE_H + +#include + +#include "defines.h" +#include "dic_node_priority_queue.h" + +#define INITIAL_QUEUE_ID_ACTIVE 0 +#define INITIAL_QUEUE_ID_NEXT_ACTIVE 1 +#define INITIAL_QUEUE_ID_TERMINAL 2 +#define INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION 3 +#define PRIORITY_QUEUES_SIZE 4 + +namespace latinime { + +class DicNode; + +/** + * Class for controlling dicNode search priority queue and lexicon trie traversal. + */ +class DicNodesCache { + public: + AK_FORCE_INLINE DicNodesCache() + : mActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_ACTIVE]), + mNextActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_NEXT_ACTIVE]), + mTerminalDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_TERMINAL]), + mCachedDicNodesForContinuousSuggestion( + &mDicNodePriorityQueues[INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION]), + mInputIndex(0), mLastCachedInputIndex(0) { + } + + AK_FORCE_INLINE virtual ~DicNodesCache() {} + + AK_FORCE_INLINE void reset(const int nextActiveSize, const int terminalSize) { + mInputIndex = 0; + mLastCachedInputIndex = 0; + mActiveDicNodes->reset(); + mNextActiveDicNodes->clearAndResize(nextActiveSize); + mTerminalDicNodes->clearAndResize(terminalSize); + mCachedDicNodesForContinuousSuggestion->reset(); + } + + AK_FORCE_INLINE void continueSearch() { + resetTemporaryCaches(); + restoreActiveDicNodesFromCache(); + } + + AK_FORCE_INLINE void advanceActiveDicNodes() { + if (DEBUG_DICT) { + AKLOGI("Advance active %d nodes.", mNextActiveDicNodes->getSize()); + } + if (DEBUG_DICT_FULL) { + mNextActiveDicNodes->dump(); + } + mNextActiveDicNodes = + moveNodesAndReturnReusableEmptyQueue(mNextActiveDicNodes, &mActiveDicNodes); + } + + DicNode *setCommitPoint(int commitPoint); + + int activeSize() const { return mActiveDicNodes->getSize(); } + int terminalSize() const { return mTerminalDicNodes->getSize(); } + bool isLookAheadCorrectionInputIndex(const int inputIndex) const { + return inputIndex == mInputIndex - 1; + } + void advanceInputIndex(const int inputSize) { + if (mInputIndex < inputSize) { + mInputIndex++; + } + } + + AK_FORCE_INLINE void copyPushTerminal(DicNode *dicNode) { + mTerminalDicNodes->copyPush(dicNode); + } + + AK_FORCE_INLINE void copyPushActive(DicNode *dicNode) { + mActiveDicNodes->copyPush(dicNode); + } + + AK_FORCE_INLINE bool copyPushContinue(DicNode *dicNode) { + return mCachedDicNodesForContinuousSuggestion->copyPush(dicNode); + } + + AK_FORCE_INLINE void copyPushNextActive(DicNode *dicNode) { + DicNode *pushedDicNode = mNextActiveDicNodes->copyPush(dicNode); + if (!pushedDicNode) { + if (dicNode->isCached()) { + dicNode->remove(); + } + // We simply drop any dic node that was not cached, ignoring the slim chance + // that one of its children represents what the user really wanted. + } + } + + void popTerminal(DicNode *dest) { + mTerminalDicNodes->copyPop(dest); + } + + void popActive(DicNode *dest) { + mActiveDicNodes->copyPop(dest); + } + + bool hasCachedDicNodesForContinuousSuggestion() const { + return mCachedDicNodesForContinuousSuggestion + && mCachedDicNodesForContinuousSuggestion->getSize() > 0; + } + + AK_FORCE_INLINE bool isCacheBorderForTyping(const int inputSize) const { + // TODO: Move this variable to header + static const int CACHE_BACK_LENGTH = 3; + const int cacheInputIndex = inputSize - CACHE_BACK_LENGTH; + const bool shouldCache = (cacheInputIndex == mInputIndex) + && (cacheInputIndex != mLastCachedInputIndex); + return shouldCache; + } + + AK_FORCE_INLINE void updateLastCachedInputIndex() { + mLastCachedInputIndex = mInputIndex; + } + + private: + DISALLOW_COPY_AND_ASSIGN(DicNodesCache); + + AK_FORCE_INLINE void restoreActiveDicNodesFromCache() { + if (DEBUG_DICT) { + AKLOGI("Restore %d nodes. inputIndex = %d.", + mCachedDicNodesForContinuousSuggestion->getSize(), mLastCachedInputIndex); + } + if (DEBUG_DICT_FULL || DEBUG_CACHE) { + mCachedDicNodesForContinuousSuggestion->dump(); + } + mInputIndex = mLastCachedInputIndex; + mCachedDicNodesForContinuousSuggestion = + moveNodesAndReturnReusableEmptyQueue( + mCachedDicNodesForContinuousSuggestion, &mActiveDicNodes); + } + + AK_FORCE_INLINE static DicNodePriorityQueue *moveNodesAndReturnReusableEmptyQueue( + DicNodePriorityQueue *src, DicNodePriorityQueue **dest) { + const int srcMaxSize = src->getMaxSize(); + const int destMaxSize = (*dest)->getMaxSize(); + DicNodePriorityQueue *tmp = *dest; + *dest = src; + (*dest)->setMaxSize(destMaxSize); + tmp->clearAndResize(srcMaxSize); + return tmp; + } + + AK_FORCE_INLINE void resetTemporaryCaches() { + mActiveDicNodes->clear(); + mNextActiveDicNodes->clear(); + mTerminalDicNodes->clear(); + } + + DicNodePriorityQueue mDicNodePriorityQueues[PRIORITY_QUEUES_SIZE]; + // Active dicNodes currently being expanded. + DicNodePriorityQueue *mActiveDicNodes; + // Next dicNodes to be expanded. + DicNodePriorityQueue *mNextActiveDicNodes; + // Current top terminal dicNodes. + DicNodePriorityQueue *mTerminalDicNodes; + // Cached dicNodes used for continuous suggestion. + DicNodePriorityQueue *mCachedDicNodesForContinuousSuggestion; + int mInputIndex; + int mLastCachedInputIndex; +}; +} // namespace latinime +#endif // LATINIME_DIC_NODES_CACHE_H diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h new file mode 100644 index 000000000..b8c10e25a --- /dev/null +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SCORING_H +#define LATINIME_SCORING_H + +#include "defines.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; + +// This class basically tweaks suggestions and distances apart from CompoundDistance +class Scoring { + public: + virtual int calculateFinalScore(const float compoundDistance, const int inputSize, + const bool forceCommit) const = 0; + virtual bool getMostProbableString( + const DicTraverseSession *const traverseSession, const int terminalSize, + const float languageWeight, int *const outputCodePoints, int *const type, + int *const freq) const = 0; + virtual void safetyNetForMostProbableString(const int terminalSize, + const int maxScore, int *const outputCodePoints, int *const frequencies) const = 0; + // TODO: Make more generic + virtual void searchWordWithDoubleLetter(DicNode *terminals, + const int terminalSize, int *doubleLetterTerminalIndex, + DoubleLetterLevel *doubleLetterLevel) const = 0; + virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, + DicNode *const terminals, const int size) const = 0; + virtual float getDoubleLetterDemotionDistanceCost(const int terminalIndex, + const int doubleLetterTerminalIndex, + const DoubleLetterLevel doubleLetterLevel) const = 0; + virtual bool doesAutoCorrectValidWord() const = 0; + + protected: + Scoring() {} + virtual ~Scoring() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Scoring); +}; +} // namespace latinime +#endif // LATINIME_SCORING_H diff --git a/native/jni/src/suggest/core/policy/suggest_policy.h b/native/jni/src/suggest/core/policy/suggest_policy.h new file mode 100644 index 000000000..885e214f7 --- /dev/null +++ b/native/jni/src/suggest/core/policy/suggest_policy.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SUGGEST_POLICY_H +#define LATINIME_SUGGEST_POLICY_H + +#include "defines.h" + +namespace latinime { +class Traversal; +class Scoring; +class Weighting; + +class SuggestPolicy { + public: + SuggestPolicy() {} + virtual ~SuggestPolicy() {} + virtual const Traversal *getTraversal() const = 0; + virtual const Scoring *getScoring() const = 0; + virtual const Weighting *getWeighting() const = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(SuggestPolicy); +}; +} // namespace latinime +#endif // LATINIME_SUGGEST_POLICY_H diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h new file mode 100644 index 000000000..1d5082ff8 --- /dev/null +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TRAVERSAL_H +#define LATINIME_TRAVERSAL_H + +#include "defines.h" + +namespace latinime { +class Traversal { + public: + virtual int getMaxPointerCount() const = 0; + virtual bool allowsErrorCorrections(const DicNode *const dicNode) const = 0; + virtual bool isOmission(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, const DicNode *const childDicNode) const = 0; + virtual bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool shouldDepthLevelCache(const DicTraverseSession *const traverseSession) const = 0; + virtual bool shouldNodeLevelCache(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual ProximityType getProximityType( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + const DicNode *const childDicNode) const = 0; + virtual bool sameAsTyped(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool needsToTraverseAllUserInput() const = 0; + virtual float getMaxSpatialDistance() const = 0; + virtual bool allowPartialCommit() const = 0; + virtual int getDefaultExpandDicNodeSize() const = 0; + virtual int getMaxCacheSize() const = 0; + virtual bool isPossibleOmissionChildNode( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const = 0; + virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0; + + protected: + Traversal() {} + virtual ~Traversal() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Traversal); +}; +} // namespace latinime +#endif // LATINIME_TRAVERSAL_H diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp new file mode 100644 index 000000000..4d08fa0fa --- /dev/null +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "char_utils.h" +#include "defines.h" +#include "dic_node.h" +#include "dic_node_profiler.h" +#include "dic_node_utils.h" +#include "dic_traverse_session.h" +#include "hash_map_compat.h" +#include "weighting.h" + +namespace latinime { + +static inline void profile(const CorrectionType correctionType, DicNode *const node) { +#if DEBUG_DICT + switch (correctionType) { + case CT_OMISSION: + PROF_OMISSION(node->mProfiler); + return; + case CT_ADDITIONAL_PROXIMITY: + PROF_ADDITIONAL_PROXIMITY(node->mProfiler); + return; + case CT_SUBSTITUTION: + PROF_SUBSTITUTION(node->mProfiler); + return; + case CT_NEW_WORD: + PROF_NEW_WORD(node->mProfiler); + return; + case CT_MATCH: + PROF_MATCH(node->mProfiler); + return; + case CT_COMPLETION: + PROF_COMPLETION(node->mProfiler); + return; + case CT_TERMINAL: + PROF_TERMINAL(node->mProfiler); + return; + case CT_SPACE_SUBSTITUTION: + PROF_SPACE_SUBSTITUTION(node->mProfiler); + return; + case CT_INSERTION: + PROF_INSERTION(node->mProfiler); + return; + case CT_TRANSPOSITION: + PROF_TRANSPOSITION(node->mProfiler); + return; + default: + // do nothing + return; + } +#else + // do nothing +#endif +} + +/* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, DicNode *const dicNode, + hash_map_compat *const bigramCacheMap) { + const int inputSize = traverseSession->getInputSize(); + DicNode_InputStateG inputStateG; + inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default + const float spatialCost = Weighting::getSpatialCost(weighting, correctionType, + traverseSession, parentDicNode, dicNode, &inputStateG); + const float languageCost = Weighting::getLanguageCost(weighting, correctionType, + traverseSession, parentDicNode, dicNode, bigramCacheMap); + const bool edit = Weighting::isEditCorrection(correctionType); + const bool proximity = Weighting::isProximityCorrection(weighting, correctionType, + traverseSession, dicNode); + profile(correctionType, dicNode); + if (inputStateG.mNeedsToUpdateInputStateG) { + dicNode->updateInputIndexG(&inputStateG); + } else { + dicNode->forwardInputIndex(0, getForwardInputCount(correctionType), + (correctionType == CT_TRANSPOSITION)); + } + dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(), + inputSize, edit, proximity); +} + +/* static */ float Weighting::getSpatialCost(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) { + switch(correctionType) { + case CT_OMISSION: + return weighting->getOmissionCost(parentDicNode, dicNode); + case CT_ADDITIONAL_PROXIMITY: + // only used for typing + return weighting->getAdditionalProximityCost(); + case CT_SUBSTITUTION: + // only used for typing + return weighting->getSubstitutionCost(); + case CT_NEW_WORD: + return weighting->getNewWordCost(dicNode); + case CT_MATCH: + return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); + case CT_COMPLETION: + return weighting->getCompletionCost(traverseSession, dicNode); + case CT_TERMINAL: + return weighting->getTerminalSpatialCost(traverseSession, dicNode); + case CT_SPACE_SUBSTITUTION: + return weighting->getSpaceSubstitutionCost(); + case CT_INSERTION: + return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode); + case CT_TRANSPOSITION: + return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode); + default: + return 0.0f; + } +} + +/* static */ float Weighting::getLanguageCost(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + hash_map_compat *const bigramCacheMap) { + switch(correctionType) { + case CT_OMISSION: + return 0.0f; + case CT_SUBSTITUTION: + return 0.0f; + case CT_NEW_WORD: + return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); + case CT_MATCH: + return 0.0f; + case CT_COMPLETION: + return 0.0f; + case CT_TERMINAL: { + const float languageImprobability = + DicNodeUtils::getBigramNodeImprobability( + traverseSession->getOffsetDict(), dicNode, bigramCacheMap); + return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); + } + case CT_SPACE_SUBSTITUTION: + return 0.0f; + case CT_INSERTION: + return 0.0f; + case CT_TRANSPOSITION: + return 0.0f; + default: + return 0.0f; + } +} + +/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) { + switch(correctionType) { + case CT_OMISSION: + return true; + case CT_ADDITIONAL_PROXIMITY: + // Should return true? + return false; + case CT_SUBSTITUTION: + // Should return true? + return false; + case CT_NEW_WORD: + return false; + case CT_MATCH: + return false; + case CT_COMPLETION: + return false; + case CT_TERMINAL: + return false; + case CT_SPACE_SUBSTITUTION: + return false; + case CT_INSERTION: + return true; + case CT_TRANSPOSITION: + return true; + default: + return false; + } +} + +/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) { + switch(correctionType) { + case CT_OMISSION: + return false; + case CT_ADDITIONAL_PROXIMITY: + return false; + case CT_SUBSTITUTION: + return false; + case CT_NEW_WORD: + return false; + case CT_MATCH: + return weighting->isProximityDicNode(traverseSession, dicNode); + case CT_COMPLETION: + return false; + case CT_TERMINAL: + return false; + case CT_SPACE_SUBSTITUTION: + return false; + case CT_INSERTION: + return false; + case CT_TRANSPOSITION: + return false; + default: + return false; + } +} + +/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) { + switch(correctionType) { + case CT_OMISSION: + return 0; + case CT_ADDITIONAL_PROXIMITY: + return 0; + case CT_SUBSTITUTION: + return 0; + case CT_NEW_WORD: + return 0; + case CT_MATCH: + return 1; + case CT_COMPLETION: + return 0; + case CT_TERMINAL: + return 0; + case CT_SPACE_SUBSTITUTION: + return 1; + case CT_INSERTION: + return 2; + case CT_TRANSPOSITION: + return 2; + default: + return 0; + } +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h new file mode 100644 index 000000000..83a0f4b45 --- /dev/null +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_WEIGHTING_H +#define LATINIME_WEIGHTING_H + +#include "defines.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; +struct DicNode_InputStateG; + +class Weighting { + public: + static void addCostAndForwardInputIndex(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, DicNode *const dicNode, + hash_map_compat *const bigramCacheMap); + + protected: + virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + + virtual float getOmissionCost( + const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; + + virtual float getMatchedCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + DicNode_InputStateG *inputStateG) const = 0; + + virtual bool isProximityDicNode(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + + virtual float getTranspositionCost( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const = 0; + + virtual float getInsertionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; + + virtual float getNewWordCost(const DicNode *const dicNode) const = 0; + + virtual float getNewWordBigramCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + hash_map_compat *const bigramCacheMap) const = 0; + + virtual float getCompletionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + + virtual float getTerminalLanguageCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + float dicNodeLanguageImprobability) const = 0; + + virtual bool needsToNormalizeCompoundDistance() const = 0; + + virtual float getAdditionalProximityCost() const = 0; + + virtual float getSubstitutionCost() const = 0; + + virtual float getSpaceSubstitutionCost() const = 0; + + Weighting() {} + virtual ~Weighting() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Weighting); + + static float getSpatialCost(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + DicNode_InputStateG *const inputStateG); + static float getLanguageCost(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + hash_map_compat *const bigramCacheMap); + // TODO: Move to TypingWeighting and GestureWeighting? + static bool isEditCorrection(const CorrectionType correctionType); + // TODO: Move to TypingWeighting and GestureWeighting? + static bool isProximityCorrection(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const dicNode); + // TODO: Move to TypingWeighting and GestureWeighting? + static int getForwardInputCount(const CorrectionType correctionType); +}; +} // namespace latinime +#endif // LATINIME_WEIGHTING_H diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp new file mode 100644 index 000000000..1f781dd43 --- /dev/null +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "defines.h" +#include "dictionary.h" +#include "dic_node_utils.h" +#include "dic_traverse_session.h" +#include "dic_traverse_wrapper.h" +#include "jni.h" + +namespace latinime { + +const int DicTraverseSession::CACHE_START_INPUT_LENGTH_THRESHOLD = 20; + +// A factory method for DicTraverseSession +static void *getSessionInstance(JNIEnv *env, jstring localeStr) { + return new DicTraverseSession(env, localeStr); +} + +// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down. +static void initSessionInstance(void *traverseSession, const Dictionary *const dictionary, + const int *prevWord, const int prevWordLength) { + if (traverseSession) { + DicTraverseSession *tSession = static_cast(traverseSession); + tSession->init(dictionary, prevWord, prevWordLength); + } +} + +// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down. +static void releaseSessionInstance(void *traverseSession) { + delete static_cast(traverseSession); +} + +// An ad-hoc internal class to register the factory method defined above +class TraverseSessionFactoryRegisterer { + public: + TraverseSessionFactoryRegisterer() { + DicTraverseWrapper::setTraverseSessionFactoryMethod(getSessionInstance); + DicTraverseWrapper::setTraverseSessionInitMethod(initSessionInstance); + DicTraverseWrapper::setTraverseSessionReleaseMethod(releaseSessionInstance); + } + private: + DISALLOW_COPY_AND_ASSIGN(TraverseSessionFactoryRegisterer); +}; + +// To invoke the TraverseSessionFactoryRegisterer constructor in the global constructor. +static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer; + +void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, + int prevWordLength) { + mDictionary = dictionary; + if (!prevWord) { + mPrevWordPos = NOT_VALID_WORD; + return; + } + mPrevWordPos = DicNodeUtils::getWordPos(dictionary->getOffsetDict(), prevWord, prevWordLength); +} + +void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, + const int *inputCodePoints, const int inputSize, const int *const inputXs, + const int *const inputYs, const int *const times, const int *const pointerIds, + const float maxSpatialDistance, const int maxPointerCount) { + mProximityInfo = pInfo; + mMaxPointerCount = maxPointerCount; + initializeProximityInfoStates(inputCodePoints, inputXs, inputYs, times, pointerIds, inputSize, + maxSpatialDistance, maxPointerCount); +} + +const uint8_t *DicTraverseSession::getOffsetDict() const { + return mDictionary->getOffsetDict(); +} + +void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { + mDicNodesCache.reset(nextActiveCacheSize, maxWords); + mBigramCacheMap.clear(); + mPartiallyCommited = false; +} + +void DicTraverseSession::initializeProximityInfoStates(const int *const inputCodePoints, + const int *const inputXs, const int *const inputYs, const int *const times, + const int *const pointerIds, const int inputSize, const float maxSpatialDistance, + const int maxPointerCount) { + ASSERT(1 <= maxPointerCount && maxPointerCount <= MAX_POINTER_COUNT_G); + mInputSize = 0; + for (int i = 0; i < maxPointerCount; ++i) { + mProximityInfoStates[i].initInputParams(i, maxSpatialDistance, getProximityInfo(), + inputCodePoints, inputSize, inputXs, inputYs, times, pointerIds, + maxPointerCount == MAX_POINTER_COUNT_G + /* TODO: this is a hack. fix proximity info state */); + mInputSize += mProximityInfoStates[i].size(); + } +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h new file mode 100644 index 000000000..af036f82b --- /dev/null +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_TRAVERSE_SESSION_H +#define LATINIME_DIC_TRAVERSE_SESSION_H + +#include +#include + +#include "defines.h" +#include "dic_nodes_cache.h" +#include "hash_map_compat.h" +#include "jni.h" +#include "proximity_info_state.h" + +namespace latinime { + +class Dictionary; +class ProximityInfo; + +class DicTraverseSession { + public: + AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) + : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), + mDictionary(0), mDicNodesCache(), mBigramCacheMap(), + mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1) { + // NOTE: mProximityInfoStates is an array of instances. + // No need to initialize it explicitly here. + } + + // Non virtual inline destructor -- never inherit this class + AK_FORCE_INLINE ~DicTraverseSession() {} + + void init(const Dictionary *dictionary, const int *prevWord, int prevWordLength); + // TODO: Remove and merge into init + void setupForGetSuggestions(const ProximityInfo *pInfo, const int *inputCodePoints, + const int inputSize, const int *const inputXs, const int *const inputYs, + const int *const times, const int *const pointerIds, const float maxSpatialDistance, + const int maxPointerCount); + void resetCache(const int nextActiveCacheSize, const int maxWords); + + const uint8_t *getOffsetDict() const; + bool canUseCache() const; + + //-------------------- + // getters and setters + //-------------------- + const ProximityInfo *getProximityInfo() const { return mProximityInfo; } + int getPrevWordPos() const { return mPrevWordPos; } + // TODO: REMOVE + void setPrevWordPos(int pos) { mPrevWordPos = pos; } + // TODO: Use proper parameter when changed + int getDicRootPos() const { return 0; } + DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } + hash_map_compat *getBigramCacheMap() { return &mBigramCacheMap; } + const ProximityInfoState *getProximityInfoState(int id) const { + return &mProximityInfoStates[id]; + } + int getInputSize() const { return mInputSize; } + void setPartiallyCommited() { mPartiallyCommited = true; } + bool isPartiallyCommited() const { return mPartiallyCommited; } + + bool isOnlyOnePointerUsed(int *pointerId) const { + // Not in the dictionary word + int usedPointerCount = 0; + int usedPointerId = 0; + for (int i = 0; i < mMaxPointerCount; ++i) { + if (mProximityInfoStates[i].isUsed()) { + ++usedPointerCount; + usedPointerId = i; + } + } + if (usedPointerCount != 1) { + return false; + } + *pointerId = usedPointerId; + return true; + } + + void getSearchKeys(const DicNode *node, std::vector *const outputSearchKeyVector) const { + for (int i = 0; i < MAX_POINTER_COUNT_G; ++i) { + if (!mProximityInfoStates[i].isUsed()) { + continue; + } + const int pointerId = node->getInputIndex(i); + const std::vector *const searchKeyVector = + mProximityInfoStates[i].getSearchKeyVector(pointerId); + outputSearchKeyVector->insert(outputSearchKeyVector->end(), searchKeyVector->begin(), + searchKeyVector->end()); + } + } + + ProximityType getProximityTypeG(const DicNode *const node, const int childCodePoint) const { + ProximityType proximityType = UNRELATED_CHAR; + for (int i = 0; i < MAX_POINTER_COUNT_G; ++i) { + if (!mProximityInfoStates[i].isUsed()) { + continue; + } + const int pointerId = node->getInputIndex(i); + proximityType = mProximityInfoStates[i].getProximityTypeG(pointerId, childCodePoint); + ASSERT(proximityType == UNRELATED_CHAR || proximityType == MATCH_CHAR); + // TODO: Make this more generic + // Currently we assume there are only two types here -- UNRELATED_CHAR + // and MATCH_CHAR + if (proximityType != UNRELATED_CHAR) { + return proximityType; + } + } + return proximityType; + } + + AK_FORCE_INLINE bool isCacheBorderForTyping(const int inputSize) const { + return mDicNodesCache.isCacheBorderForTyping(inputSize); + } + + /** + * Returns whether or not it is possible to continue suggestion from the previous search. + */ + // TODO: Remove. No need to check once the session is fully implemented. + bool isContinuousSuggestionPossible() const { + if (!mDicNodesCache.hasCachedDicNodesForContinuousSuggestion()) { + return false; + } + ASSERT(mMaxPointerCount < MAX_POINTER_COUNT_G); + for (int i = 0; i < mMaxPointerCount; ++i) { + const ProximityInfoState *const pInfoState = getProximityInfoState(i); + // If a proximity info state is not continuous suggestion possible, + // do not continue searching. + if (pInfoState->isUsed() && !pInfoState->isContinuousSuggestionPossible()) { + return false; + } + } + return true; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); + // threshold to start caching + static const int CACHE_START_INPUT_LENGTH_THRESHOLD; + void initializeProximityInfoStates(const int *const inputCodePoints, const int *const inputXs, + const int *const inputYs, const int *const times, const int *const pointerIds, + const int inputSize, const float maxSpatialDistance, const int maxPointerCount); + + int mPrevWordPos; + const ProximityInfo *mProximityInfo; + const Dictionary *mDictionary; + + DicNodesCache mDicNodesCache; + // Temporary cache for bigram frequencies + hash_map_compat mBigramCacheMap; + ProximityInfoState mProximityInfoStates[MAX_POINTER_COUNT_G]; + + int mInputSize; + bool mPartiallyCommited; + int mMaxPointerCount; +}; +} // namespace latinime +#endif // LATINIME_DIC_TRAVERSE_SESSION_H