1 /* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_DIC_NODE_STATE_SCORING_H 18 #define LATINIME_DIC_NODE_STATE_SCORING_H 19 20 #include <algorithm> 21 #include <cstdint> 22 23 #include "defines.h" 24 #include "suggest/core/dictionary/digraph_utils.h" 25 #include "suggest/core/dictionary/error_type_utils.h" 26 27 namespace latinime { 28 29 class DicNodeStateScoring { 30 public: DicNodeStateScoring()31 AK_FORCE_INLINE DicNodeStateScoring() 32 : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER), 33 mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), 34 mEditCorrectionCount(0), mProximityCorrectionCount(0), mCompletionCount(0), 35 mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), 36 mRawLength(0.0f), mContainedErrorTypes(ErrorTypeUtils::NOT_AN_ERROR), 37 mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) { 38 } 39 ~DicNodeStateScoring()40 ~DicNodeStateScoring() {} 41 init()42 void init() { 43 mEditCorrectionCount = 0; 44 mProximityCorrectionCount = 0; 45 mCompletionCount = 0; 46 mNormalizedCompoundDistance = 0.0f; 47 mSpatialDistance = 0.0f; 48 mLanguageDistance = 0.0f; 49 mRawLength = 0.0f; 50 mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; 51 mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; 52 mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING; 53 mContainedErrorTypes = ErrorTypeUtils::NOT_AN_ERROR; 54 } 55 initByCopy(const DicNodeStateScoring * const scoring)56 AK_FORCE_INLINE void initByCopy(const DicNodeStateScoring *const scoring) { 57 mEditCorrectionCount = scoring->mEditCorrectionCount; 58 mProximityCorrectionCount = scoring->mProximityCorrectionCount; 59 mCompletionCount = scoring->mCompletionCount; 60 mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance; 61 mSpatialDistance = scoring->mSpatialDistance; 62 mLanguageDistance = scoring->mLanguageDistance; 63 mRawLength = scoring->mRawLength; 64 mDoubleLetterLevel = scoring->mDoubleLetterLevel; 65 mDigraphIndex = scoring->mDigraphIndex; 66 mContainedErrorTypes = scoring->mContainedErrorTypes; 67 mNormalizedCompoundDistanceAfterFirstWord = 68 scoring->mNormalizedCompoundDistanceAfterFirstWord; 69 } 70 addCost(const float spatialCost,const float languageCost,const bool doNormalization,const int inputSize,const int totalInputIndex,const ErrorTypeUtils::ErrorType errorType)71 void addCost(const float spatialCost, const float languageCost, const bool doNormalization, 72 const int inputSize, const int totalInputIndex, 73 const ErrorTypeUtils::ErrorType errorType) { 74 addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex); 75 mContainedErrorTypes = mContainedErrorTypes | errorType; 76 if (ErrorTypeUtils::isEditCorrectionError(errorType)) { 77 ++mEditCorrectionCount; 78 } 79 if (ErrorTypeUtils::isProximityCorrectionError(errorType)) { 80 ++mProximityCorrectionCount; 81 } 82 if (ErrorTypeUtils::isCompletion(errorType)) { 83 ++mCompletionCount; 84 } 85 } 86 87 // Saves the current normalized distance for space-aware gestures. 88 // See getNormalizedCompoundDistanceAfterFirstWord for details. saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet()89 void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() { 90 // We get called here after each word. We only want to store the distance after 91 // the first word, so if we already have a distance we skip saving -- hence "IfNoneYet" 92 // in the method name. 93 if (mNormalizedCompoundDistanceAfterFirstWord >= MAX_VALUE_FOR_WEIGHTING) { 94 mNormalizedCompoundDistanceAfterFirstWord = getNormalizedCompoundDistance(); 95 } 96 } 97 addRawLength(const float rawLength)98 void addRawLength(const float rawLength) { 99 mRawLength += rawLength; 100 } 101 getCompoundDistance()102 float getCompoundDistance() const { 103 return getCompoundDistance(1.0f); 104 } 105 getCompoundDistance(const float weightOfLangModelVsSpatialModel)106 float getCompoundDistance( 107 const float weightOfLangModelVsSpatialModel) const { 108 return mSpatialDistance 109 + mLanguageDistance * weightOfLangModelVsSpatialModel; 110 } 111 getNormalizedCompoundDistance()112 float getNormalizedCompoundDistance() const { 113 return mNormalizedCompoundDistance; 114 } 115 116 // For space-aware gestures, we store the normalized distance at the char index 117 // that ends the first word of the suggestion. We call this the distance after 118 // first word. getNormalizedCompoundDistanceAfterFirstWord()119 float getNormalizedCompoundDistanceAfterFirstWord() const { 120 return mNormalizedCompoundDistanceAfterFirstWord; 121 } 122 getSpatialDistance()123 float getSpatialDistance() const { 124 return mSpatialDistance; 125 } 126 getLanguageDistance()127 float getLanguageDistance() const { 128 return mLanguageDistance; 129 } 130 getEditCorrectionCount()131 int16_t getEditCorrectionCount() const { 132 return mEditCorrectionCount; 133 } 134 getProximityCorrectionCount()135 int16_t getProximityCorrectionCount() const { 136 return mProximityCorrectionCount; 137 } 138 getCompletionCount()139 int16_t getCompletionCount() const { 140 return mCompletionCount; 141 } 142 getRawLength()143 float getRawLength() const { 144 return mRawLength; 145 } 146 getDoubleLetterLevel()147 DoubleLetterLevel getDoubleLetterLevel() const { 148 return mDoubleLetterLevel; 149 } 150 setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel)151 void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { 152 switch(doubleLetterLevel) { 153 case NOT_A_DOUBLE_LETTER: 154 break; 155 case A_DOUBLE_LETTER: 156 if (mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) { 157 mDoubleLetterLevel = doubleLetterLevel; 158 } 159 break; 160 case A_STRONG_DOUBLE_LETTER: 161 mDoubleLetterLevel = doubleLetterLevel; 162 break; 163 } 164 } 165 getDigraphIndex()166 DigraphUtils::DigraphCodePointIndex getDigraphIndex() const { 167 return mDigraphIndex; 168 } 169 advanceDigraphIndex()170 void advanceDigraphIndex() { 171 switch(mDigraphIndex) { 172 case DigraphUtils::NOT_A_DIGRAPH_INDEX: 173 mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT; 174 break; 175 case DigraphUtils::FIRST_DIGRAPH_CODEPOINT: 176 mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT; 177 break; 178 case DigraphUtils::SECOND_DIGRAPH_CODEPOINT: 179 mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; 180 break; 181 } 182 } 183 getContainedErrorTypes()184 ErrorTypeUtils::ErrorType getContainedErrorTypes() const { 185 return mContainedErrorTypes; 186 } 187 188 private: 189 DISALLOW_COPY_AND_ASSIGN(DicNodeStateScoring); 190 191 DoubleLetterLevel mDoubleLetterLevel; 192 DigraphUtils::DigraphCodePointIndex mDigraphIndex; 193 194 int16_t mEditCorrectionCount; 195 int16_t mProximityCorrectionCount; 196 int16_t mCompletionCount; 197 198 float mNormalizedCompoundDistance; 199 float mSpatialDistance; 200 float mLanguageDistance; 201 float mRawLength; 202 // All accumulated error types so far 203 ErrorTypeUtils::ErrorType mContainedErrorTypes; 204 float mNormalizedCompoundDistanceAfterFirstWord; 205 addDistance(float spatialDistance,float languageDistance,bool doNormalization,int inputSize,int totalInputIndex)206 AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, 207 bool doNormalization, int inputSize, int totalInputIndex) { 208 mSpatialDistance += spatialDistance; 209 mLanguageDistance += languageDistance; 210 if (!doNormalization) { 211 mNormalizedCompoundDistance = mSpatialDistance + mLanguageDistance; 212 } else { 213 mNormalizedCompoundDistance = (mSpatialDistance + mLanguageDistance) 214 / static_cast<float>(std::max(1, totalInputIndex)); 215 } 216 } 217 }; 218 } // namespace latinime 219 #endif // LATINIME_DIC_NODE_STATE_SCORING_H 220