1 /*
2  * Copyright (C) 2012 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LATINIME_DIC_NODE_STATE_SCORING_H
18 #define LATINIME_DIC_NODE_STATE_SCORING_H
19 
20 #include <algorithm>
21 #include <cstdint>
22 
23 #include "defines.h"
24 #include "suggest/core/dictionary/digraph_utils.h"
25 #include "suggest/core/dictionary/error_type_utils.h"
26 
27 namespace latinime {
28 
29 class DicNodeStateScoring {
30  public:
DicNodeStateScoring()31     AK_FORCE_INLINE DicNodeStateScoring()
32             : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER),
33               mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
34               mEditCorrectionCount(0), mProximityCorrectionCount(0), mCompletionCount(0),
35               mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
36               mRawLength(0.0f), mContainedErrorTypes(ErrorTypeUtils::NOT_AN_ERROR),
37               mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) {
38     }
39 
~DicNodeStateScoring()40     ~DicNodeStateScoring() {}
41 
init()42     void init() {
43         mEditCorrectionCount = 0;
44         mProximityCorrectionCount = 0;
45         mCompletionCount = 0;
46         mNormalizedCompoundDistance = 0.0f;
47         mSpatialDistance = 0.0f;
48         mLanguageDistance = 0.0f;
49         mRawLength = 0.0f;
50         mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
51         mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
52         mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING;
53         mContainedErrorTypes = ErrorTypeUtils::NOT_AN_ERROR;
54     }
55 
initByCopy(const DicNodeStateScoring * const scoring)56     AK_FORCE_INLINE void initByCopy(const DicNodeStateScoring *const scoring) {
57         mEditCorrectionCount = scoring->mEditCorrectionCount;
58         mProximityCorrectionCount = scoring->mProximityCorrectionCount;
59         mCompletionCount = scoring->mCompletionCount;
60         mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance;
61         mSpatialDistance = scoring->mSpatialDistance;
62         mLanguageDistance = scoring->mLanguageDistance;
63         mRawLength = scoring->mRawLength;
64         mDoubleLetterLevel = scoring->mDoubleLetterLevel;
65         mDigraphIndex = scoring->mDigraphIndex;
66         mContainedErrorTypes = scoring->mContainedErrorTypes;
67         mNormalizedCompoundDistanceAfterFirstWord =
68                 scoring->mNormalizedCompoundDistanceAfterFirstWord;
69     }
70 
addCost(const float spatialCost,const float languageCost,const bool doNormalization,const int inputSize,const int totalInputIndex,const ErrorTypeUtils::ErrorType errorType)71     void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
72             const int inputSize, const int totalInputIndex,
73             const ErrorTypeUtils::ErrorType errorType) {
74         addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
75         mContainedErrorTypes = mContainedErrorTypes | errorType;
76         if (ErrorTypeUtils::isEditCorrectionError(errorType)) {
77             ++mEditCorrectionCount;
78         }
79         if (ErrorTypeUtils::isProximityCorrectionError(errorType)) {
80             ++mProximityCorrectionCount;
81         }
82         if (ErrorTypeUtils::isCompletion(errorType)) {
83             ++mCompletionCount;
84         }
85     }
86 
87     // Saves the current normalized distance for space-aware gestures.
88     // See getNormalizedCompoundDistanceAfterFirstWord for details.
saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet()89     void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() {
90         // We get called here after each word. We only want to store the distance after
91         // the first word, so if we already have a distance we skip saving -- hence "IfNoneYet"
92         // in the method name.
93         if (mNormalizedCompoundDistanceAfterFirstWord >= MAX_VALUE_FOR_WEIGHTING) {
94             mNormalizedCompoundDistanceAfterFirstWord = getNormalizedCompoundDistance();
95         }
96     }
97 
addRawLength(const float rawLength)98     void addRawLength(const float rawLength) {
99         mRawLength += rawLength;
100     }
101 
getCompoundDistance()102     float getCompoundDistance() const {
103         return getCompoundDistance(1.0f);
104     }
105 
getCompoundDistance(const float weightOfLangModelVsSpatialModel)106     float getCompoundDistance(
107             const float weightOfLangModelVsSpatialModel) const {
108         return mSpatialDistance
109                 + mLanguageDistance * weightOfLangModelVsSpatialModel;
110     }
111 
getNormalizedCompoundDistance()112     float getNormalizedCompoundDistance() const {
113         return mNormalizedCompoundDistance;
114     }
115 
116     // For space-aware gestures, we store the normalized distance at the char index
117     // that ends the first word of the suggestion. We call this the distance after
118     // first word.
getNormalizedCompoundDistanceAfterFirstWord()119     float getNormalizedCompoundDistanceAfterFirstWord() const {
120         return mNormalizedCompoundDistanceAfterFirstWord;
121     }
122 
getSpatialDistance()123     float getSpatialDistance() const {
124         return mSpatialDistance;
125     }
126 
getLanguageDistance()127     float getLanguageDistance() const {
128         return mLanguageDistance;
129     }
130 
getEditCorrectionCount()131     int16_t getEditCorrectionCount() const {
132         return mEditCorrectionCount;
133     }
134 
getProximityCorrectionCount()135     int16_t getProximityCorrectionCount() const {
136         return mProximityCorrectionCount;
137     }
138 
getCompletionCount()139     int16_t getCompletionCount() const {
140         return mCompletionCount;
141     }
142 
getRawLength()143     float getRawLength() const {
144         return mRawLength;
145     }
146 
getDoubleLetterLevel()147     DoubleLetterLevel getDoubleLetterLevel() const {
148         return mDoubleLetterLevel;
149     }
150 
setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel)151     void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
152         switch(doubleLetterLevel) {
153             case NOT_A_DOUBLE_LETTER:
154                 break;
155             case A_DOUBLE_LETTER:
156                 if (mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) {
157                     mDoubleLetterLevel = doubleLetterLevel;
158                 }
159                 break;
160             case A_STRONG_DOUBLE_LETTER:
161                 mDoubleLetterLevel = doubleLetterLevel;
162                 break;
163         }
164     }
165 
getDigraphIndex()166     DigraphUtils::DigraphCodePointIndex getDigraphIndex() const {
167         return mDigraphIndex;
168     }
169 
advanceDigraphIndex()170     void advanceDigraphIndex() {
171         switch(mDigraphIndex) {
172             case DigraphUtils::NOT_A_DIGRAPH_INDEX:
173                 mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT;
174                 break;
175             case DigraphUtils::FIRST_DIGRAPH_CODEPOINT:
176                 mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT;
177                 break;
178             case DigraphUtils::SECOND_DIGRAPH_CODEPOINT:
179                 mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
180                 break;
181         }
182     }
183 
getContainedErrorTypes()184     ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
185         return mContainedErrorTypes;
186     }
187 
188  private:
189     DISALLOW_COPY_AND_ASSIGN(DicNodeStateScoring);
190 
191     DoubleLetterLevel mDoubleLetterLevel;
192     DigraphUtils::DigraphCodePointIndex mDigraphIndex;
193 
194     int16_t mEditCorrectionCount;
195     int16_t mProximityCorrectionCount;
196     int16_t mCompletionCount;
197 
198     float mNormalizedCompoundDistance;
199     float mSpatialDistance;
200     float mLanguageDistance;
201     float mRawLength;
202     // All accumulated error types so far
203     ErrorTypeUtils::ErrorType mContainedErrorTypes;
204     float mNormalizedCompoundDistanceAfterFirstWord;
205 
addDistance(float spatialDistance,float languageDistance,bool doNormalization,int inputSize,int totalInputIndex)206     AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
207             bool doNormalization, int inputSize, int totalInputIndex) {
208         mSpatialDistance += spatialDistance;
209         mLanguageDistance += languageDistance;
210         if (!doNormalization) {
211             mNormalizedCompoundDistance = mSpatialDistance + mLanguageDistance;
212         } else {
213             mNormalizedCompoundDistance = (mSpatialDistance + mLanguageDistance)
214                     / static_cast<float>(std::max(1, totalInputIndex));
215         }
216     }
217 };
218 } // namespace latinime
219 #endif // LATINIME_DIC_NODE_STATE_SCORING_H
220