1 /*
2  * Copyright (C) 2013, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LATINIME_DYNAMIC_PT_READING_HELPER_H
18 #define LATINIME_DYNAMIC_PT_READING_HELPER_H
19 
20 #include <cstddef>
21 #include <vector>
22 
23 #include "defines.h"
24 #include "dictionary/structure/pt_common/pt_node_params.h"
25 #include "dictionary/structure/pt_common/pt_node_reader.h"
26 
27 namespace latinime {
28 
29 class DictionaryShortcutsStructurePolicy;
30 class PtNodeArrayReader;
31 
32 /*
33  * This class is used for traversing dynamic patricia trie. This class supports iterating nodes and
34  * dealing with additional buffer. This class counts nodes and node arrays to avoid infinite loop.
35  */
36 class DynamicPtReadingHelper {
37  public:
38     class TraversingEventListener {
39      public:
~TraversingEventListener()40         virtual ~TraversingEventListener() {};
41 
42         // Returns whether the event handling was succeeded or not.
43         virtual bool onAscend() = 0;
44 
45         // Returns whether the event handling was succeeded or not.
46         virtual bool onDescend(const int ptNodeArrayPos) = 0;
47 
48         // Returns whether the event handling was succeeded or not.
49         virtual bool onReadingPtNodeArrayTail() = 0;
50 
51         // Returns whether the event handling was succeeded or not.
52         virtual bool onVisitingPtNode(const PtNodeParams *const node) = 0;
53 
54      protected:
TraversingEventListener()55         TraversingEventListener() {};
56 
57      private:
58         DISALLOW_COPY_AND_ASSIGN(TraversingEventListener);
59     };
60 
61     class TraversePolicyToGetAllTerminalPtNodePositions : public TraversingEventListener {
62      public:
TraversePolicyToGetAllTerminalPtNodePositions(std::vector<int> * const terminalPositions)63         TraversePolicyToGetAllTerminalPtNodePositions(std::vector<int> *const terminalPositions)
64                 : mTerminalPositions(terminalPositions) {}
onAscend()65         bool onAscend() { return true; }
onDescend(const int ptNodeArrayPos)66         bool onDescend(const int ptNodeArrayPos) { return true; }
onReadingPtNodeArrayTail()67         bool onReadingPtNodeArrayTail() { return true; }
68         bool onVisitingPtNode(const PtNodeParams *const ptNodeParams);
69 
70      private:
71         DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToGetAllTerminalPtNodePositions);
72 
73         std::vector<int> *const mTerminalPositions;
74     };
75 
DynamicPtReadingHelper(const PtNodeReader * const ptNodeReader,const PtNodeArrayReader * const ptNodeArrayReader)76     DynamicPtReadingHelper(const PtNodeReader *const ptNodeReader,
77             const PtNodeArrayReader *const ptNodeArrayReader)
78             : mIsError(false), mReadingState(), mPtNodeReader(ptNodeReader),
79               mPtNodeArrayReader(ptNodeArrayReader), mReadingStateStack() {}
80 
~DynamicPtReadingHelper()81     ~DynamicPtReadingHelper() {}
82 
isError()83     AK_FORCE_INLINE bool isError() const {
84         return mIsError;
85     }
86 
isEnd()87     AK_FORCE_INLINE bool isEnd() const {
88         return mReadingState.mPos == NOT_A_DICT_POS;
89     }
90 
91     // Initialize reading state with the head position of a PtNode array.
initWithPtNodeArrayPos(const int ptNodeArrayPos)92     AK_FORCE_INLINE void initWithPtNodeArrayPos(const int ptNodeArrayPos) {
93         if (ptNodeArrayPos == NOT_A_DICT_POS) {
94             mReadingState.mPos = NOT_A_DICT_POS;
95         } else {
96             mIsError = false;
97             mReadingState.mPos = ptNodeArrayPos;
98             mReadingState.mTotalCodePointCountSinceInitialization = 0;
99             mReadingState.mTotalPtNodeIndexInThisArrayChain = 0;
100             mReadingState.mPtNodeArrayIndexInThisArrayChain = 0;
101             mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS;
102             mReadingStateStack.clear();
103             nextPtNodeArray();
104         }
105     }
106 
107     // Initialize reading state with the head position of a node.
initWithPtNodePos(const int ptNodePos)108     AK_FORCE_INLINE void initWithPtNodePos(const int ptNodePos) {
109         if (ptNodePos == NOT_A_DICT_POS) {
110             mReadingState.mPos = NOT_A_DICT_POS;
111         } else {
112             mIsError = false;
113             mReadingState.mPos = ptNodePos;
114             mReadingState.mRemainingPtNodeCountInThisArray = 1;
115             mReadingState.mTotalCodePointCountSinceInitialization = 0;
116             mReadingState.mTotalPtNodeIndexInThisArrayChain = 1;
117             mReadingState.mPtNodeArrayIndexInThisArrayChain = 1;
118             mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS;
119             mReadingState.mPosOfThisPtNodeArrayHead = NOT_A_DICT_POS;
120             mReadingStateStack.clear();
121         }
122     }
123 
getPtNodeParams()124     AK_FORCE_INLINE const PtNodeParams getPtNodeParams() const {
125         if (isEnd()) {
126             return PtNodeParams();
127         }
128         return mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(mReadingState.mPos);
129     }
130 
isValidTerminalNode(const PtNodeParams & ptNodeParams)131     AK_FORCE_INLINE bool isValidTerminalNode(const PtNodeParams &ptNodeParams) const {
132         return !isEnd() && !ptNodeParams.isDeleted() && ptNodeParams.isTerminal();
133     }
134 
isMatchedCodePoint(const PtNodeParams & ptNodeParams,const int index,const int codePoint)135     AK_FORCE_INLINE bool isMatchedCodePoint(const PtNodeParams &ptNodeParams, const int index,
136             const int codePoint) const {
137         return ptNodeParams.getCodePoints()[index] == codePoint;
138     }
139 
140     // Return code point count exclude the last read node's code points.
getPrevTotalCodePointCount()141     AK_FORCE_INLINE size_t getPrevTotalCodePointCount() const {
142         return mReadingState.mTotalCodePointCountSinceInitialization;
143     }
144 
145     // Return code point count include the last read node's code points.
getTotalCodePointCount(const PtNodeParams & ptNodeParams)146     AK_FORCE_INLINE size_t getTotalCodePointCount(const PtNodeParams &ptNodeParams) const {
147         return mReadingState.mTotalCodePointCountSinceInitialization
148                 + ptNodeParams.getCodePointCount();
149     }
150 
fetchMergedNodeCodePointsInReverseOrder(const PtNodeParams & ptNodeParams,const int index,int * const outCodePoints)151     AK_FORCE_INLINE void fetchMergedNodeCodePointsInReverseOrder(const PtNodeParams &ptNodeParams,
152             const int index, int *const outCodePoints) const {
153         const int nodeCodePointCount = ptNodeParams.getCodePointCount();
154         const int *const nodeCodePoints = ptNodeParams.getCodePoints();
155         for (int i =  0; i < nodeCodePointCount; ++i) {
156             outCodePoints[index + i] = nodeCodePoints[nodeCodePointCount - 1 - i];
157         }
158     }
159 
readNextSiblingNode(const PtNodeParams & ptNodeParams)160     AK_FORCE_INLINE void readNextSiblingNode(const PtNodeParams &ptNodeParams) {
161         mReadingState.mRemainingPtNodeCountInThisArray -= 1;
162         mReadingState.mPos = ptNodeParams.getSiblingNodePos();
163         if (mReadingState.mRemainingPtNodeCountInThisArray <= 0) {
164             // All nodes in the current node array have been read.
165             followForwardLink();
166         }
167     }
168 
169     // Read the first child node of the current node.
readChildNode(const PtNodeParams & ptNodeParams)170     AK_FORCE_INLINE void readChildNode(const PtNodeParams &ptNodeParams) {
171         if (ptNodeParams.hasChildren()) {
172             mReadingState.mTotalCodePointCountSinceInitialization +=
173                     ptNodeParams.getCodePointCount();
174             mReadingState.mTotalPtNodeIndexInThisArrayChain = 0;
175             mReadingState.mPtNodeArrayIndexInThisArrayChain = 0;
176             mReadingState.mPos = ptNodeParams.getChildrenPos();
177             mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS;
178             // Read children node array.
179             nextPtNodeArray();
180         } else {
181             mReadingState.mPos = NOT_A_DICT_POS;
182         }
183     }
184 
185     // Read the parent node of the current node.
readParentNode(const PtNodeParams & ptNodeParams)186     AK_FORCE_INLINE void readParentNode(const PtNodeParams &ptNodeParams) {
187         if (ptNodeParams.getParentPos() != NOT_A_DICT_POS) {
188             mReadingState.mTotalCodePointCountSinceInitialization +=
189                     ptNodeParams.getCodePointCount();
190             mReadingState.mTotalPtNodeIndexInThisArrayChain = 1;
191             mReadingState.mPtNodeArrayIndexInThisArrayChain = 1;
192             mReadingState.mRemainingPtNodeCountInThisArray = 1;
193             mReadingState.mPos = ptNodeParams.getParentPos();
194             mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS;
195             mReadingState.mPosOfThisPtNodeArrayHead = NOT_A_DICT_POS;
196         } else {
197             mReadingState.mPos = NOT_A_DICT_POS;
198         }
199     }
200 
getPosOfLastForwardLinkField()201     AK_FORCE_INLINE int getPosOfLastForwardLinkField() const {
202         return mReadingState.mPosOfLastForwardLinkField;
203     }
204 
getPosOfLastPtNodeArrayHead()205     AK_FORCE_INLINE int getPosOfLastPtNodeArrayHead() const {
206         return mReadingState.mPosOfThisPtNodeArrayHead;
207     }
208 
209     bool traverseAllPtNodesInPostorderDepthFirstManner(TraversingEventListener *const listener);
210 
211     bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner(
212             TraversingEventListener *const listener);
213 
214     int getCodePointsAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints);
215 
216     int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length,
217             const bool forceLowerCaseSearch);
218 
219  private:
220     DISALLOW_COPY_AND_ASSIGN(DynamicPtReadingHelper);
221 
222     // This class encapsulates the reading state of a position in the dictionary. It points at a
223     // specific PtNode in the dictionary.
224     class PtNodeReadingState {
225      public:
226         // Note that copy constructor and assignment operator are used for this class to use
227         // std::vector.
PtNodeReadingState()228         PtNodeReadingState() : mPos(NOT_A_DICT_POS), mRemainingPtNodeCountInThisArray(0),
229                 mTotalCodePointCountSinceInitialization(0), mTotalPtNodeIndexInThisArrayChain(0),
230                 mPtNodeArrayIndexInThisArrayChain(0), mPosOfLastForwardLinkField(NOT_A_DICT_POS),
231                 mPosOfThisPtNodeArrayHead(NOT_A_DICT_POS) {}
232 
233         int mPos;
234         // Remaining node count in the current array.
235         int mRemainingPtNodeCountInThisArray;
236         size_t mTotalCodePointCountSinceInitialization;
237         // Counter of PtNodes used to avoid infinite loops caused by broken or malicious links.
238         int mTotalPtNodeIndexInThisArrayChain;
239         // Counter of PtNode arrays used to avoid infinite loops caused by cyclic links of empty
240         // PtNode arrays.
241         int mPtNodeArrayIndexInThisArrayChain;
242         int mPosOfLastForwardLinkField;
243         int mPosOfThisPtNodeArrayHead;
244     };
245 
246     static const int MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP;
247     static const int MAX_PT_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP;
248     static const size_t MAX_READING_STATE_STACK_SIZE;
249 
250     // TODO: Introduce error code to track what caused the error.
251     bool mIsError;
252     PtNodeReadingState mReadingState;
253     const PtNodeReader *const mPtNodeReader;
254     const PtNodeArrayReader *const mPtNodeArrayReader;
255     std::vector<PtNodeReadingState> mReadingStateStack;
256 
257     void nextPtNodeArray();
258 
259     void followForwardLink();
260 
pushReadingStateToStack()261     AK_FORCE_INLINE void pushReadingStateToStack() {
262         if (mReadingStateStack.size() > MAX_READING_STATE_STACK_SIZE) {
263             AKLOGI("Reading state stack overflow. Max size: %zd", MAX_READING_STATE_STACK_SIZE);
264             ASSERT(false);
265             mIsError = true;
266             mReadingState.mPos = NOT_A_DICT_POS;
267         } else {
268             mReadingStateStack.push_back(mReadingState);
269         }
270     }
271 
popReadingStateFromStack()272     AK_FORCE_INLINE void popReadingStateFromStack() {
273         if (mReadingStateStack.empty()) {
274             mReadingState.mPos = NOT_A_DICT_POS;
275         } else {
276             mReadingState = mReadingStateStack.back();
277             mReadingStateStack.pop_back();
278         }
279     }
280 };
281 } // namespace latinime
282 #endif /* LATINIME_DYNAMIC_PT_READING_HELPER_H */
283