1 /* 2 * Copyright (C) 2013, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_DYNAMIC_PT_READING_HELPER_H 18 #define LATINIME_DYNAMIC_PT_READING_HELPER_H 19 20 #include <cstddef> 21 #include <vector> 22 23 #include "defines.h" 24 #include "dictionary/structure/pt_common/pt_node_params.h" 25 #include "dictionary/structure/pt_common/pt_node_reader.h" 26 27 namespace latinime { 28 29 class DictionaryShortcutsStructurePolicy; 30 class PtNodeArrayReader; 31 32 /* 33 * This class is used for traversing dynamic patricia trie. This class supports iterating nodes and 34 * dealing with additional buffer. This class counts nodes and node arrays to avoid infinite loop. 35 */ 36 class DynamicPtReadingHelper { 37 public: 38 class TraversingEventListener { 39 public: ~TraversingEventListener()40 virtual ~TraversingEventListener() {}; 41 42 // Returns whether the event handling was succeeded or not. 43 virtual bool onAscend() = 0; 44 45 // Returns whether the event handling was succeeded or not. 46 virtual bool onDescend(const int ptNodeArrayPos) = 0; 47 48 // Returns whether the event handling was succeeded or not. 49 virtual bool onReadingPtNodeArrayTail() = 0; 50 51 // Returns whether the event handling was succeeded or not. 52 virtual bool onVisitingPtNode(const PtNodeParams *const node) = 0; 53 54 protected: TraversingEventListener()55 TraversingEventListener() {}; 56 57 private: 58 DISALLOW_COPY_AND_ASSIGN(TraversingEventListener); 59 }; 60 61 class TraversePolicyToGetAllTerminalPtNodePositions : public TraversingEventListener { 62 public: TraversePolicyToGetAllTerminalPtNodePositions(std::vector<int> * const terminalPositions)63 TraversePolicyToGetAllTerminalPtNodePositions(std::vector<int> *const terminalPositions) 64 : mTerminalPositions(terminalPositions) {} onAscend()65 bool onAscend() { return true; } onDescend(const int ptNodeArrayPos)66 bool onDescend(const int ptNodeArrayPos) { return true; } onReadingPtNodeArrayTail()67 bool onReadingPtNodeArrayTail() { return true; } 68 bool onVisitingPtNode(const PtNodeParams *const ptNodeParams); 69 70 private: 71 DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToGetAllTerminalPtNodePositions); 72 73 std::vector<int> *const mTerminalPositions; 74 }; 75 DynamicPtReadingHelper(const PtNodeReader * const ptNodeReader,const PtNodeArrayReader * const ptNodeArrayReader)76 DynamicPtReadingHelper(const PtNodeReader *const ptNodeReader, 77 const PtNodeArrayReader *const ptNodeArrayReader) 78 : mIsError(false), mReadingState(), mPtNodeReader(ptNodeReader), 79 mPtNodeArrayReader(ptNodeArrayReader), mReadingStateStack() {} 80 ~DynamicPtReadingHelper()81 ~DynamicPtReadingHelper() {} 82 isError()83 AK_FORCE_INLINE bool isError() const { 84 return mIsError; 85 } 86 isEnd()87 AK_FORCE_INLINE bool isEnd() const { 88 return mReadingState.mPos == NOT_A_DICT_POS; 89 } 90 91 // Initialize reading state with the head position of a PtNode array. initWithPtNodeArrayPos(const int ptNodeArrayPos)92 AK_FORCE_INLINE void initWithPtNodeArrayPos(const int ptNodeArrayPos) { 93 if (ptNodeArrayPos == NOT_A_DICT_POS) { 94 mReadingState.mPos = NOT_A_DICT_POS; 95 } else { 96 mIsError = false; 97 mReadingState.mPos = ptNodeArrayPos; 98 mReadingState.mTotalCodePointCountSinceInitialization = 0; 99 mReadingState.mTotalPtNodeIndexInThisArrayChain = 0; 100 mReadingState.mPtNodeArrayIndexInThisArrayChain = 0; 101 mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; 102 mReadingStateStack.clear(); 103 nextPtNodeArray(); 104 } 105 } 106 107 // Initialize reading state with the head position of a node. initWithPtNodePos(const int ptNodePos)108 AK_FORCE_INLINE void initWithPtNodePos(const int ptNodePos) { 109 if (ptNodePos == NOT_A_DICT_POS) { 110 mReadingState.mPos = NOT_A_DICT_POS; 111 } else { 112 mIsError = false; 113 mReadingState.mPos = ptNodePos; 114 mReadingState.mRemainingPtNodeCountInThisArray = 1; 115 mReadingState.mTotalCodePointCountSinceInitialization = 0; 116 mReadingState.mTotalPtNodeIndexInThisArrayChain = 1; 117 mReadingState.mPtNodeArrayIndexInThisArrayChain = 1; 118 mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; 119 mReadingState.mPosOfThisPtNodeArrayHead = NOT_A_DICT_POS; 120 mReadingStateStack.clear(); 121 } 122 } 123 getPtNodeParams()124 AK_FORCE_INLINE const PtNodeParams getPtNodeParams() const { 125 if (isEnd()) { 126 return PtNodeParams(); 127 } 128 return mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(mReadingState.mPos); 129 } 130 isValidTerminalNode(const PtNodeParams & ptNodeParams)131 AK_FORCE_INLINE bool isValidTerminalNode(const PtNodeParams &ptNodeParams) const { 132 return !isEnd() && !ptNodeParams.isDeleted() && ptNodeParams.isTerminal(); 133 } 134 isMatchedCodePoint(const PtNodeParams & ptNodeParams,const int index,const int codePoint)135 AK_FORCE_INLINE bool isMatchedCodePoint(const PtNodeParams &ptNodeParams, const int index, 136 const int codePoint) const { 137 return ptNodeParams.getCodePoints()[index] == codePoint; 138 } 139 140 // Return code point count exclude the last read node's code points. getPrevTotalCodePointCount()141 AK_FORCE_INLINE size_t getPrevTotalCodePointCount() const { 142 return mReadingState.mTotalCodePointCountSinceInitialization; 143 } 144 145 // Return code point count include the last read node's code points. getTotalCodePointCount(const PtNodeParams & ptNodeParams)146 AK_FORCE_INLINE size_t getTotalCodePointCount(const PtNodeParams &ptNodeParams) const { 147 return mReadingState.mTotalCodePointCountSinceInitialization 148 + ptNodeParams.getCodePointCount(); 149 } 150 fetchMergedNodeCodePointsInReverseOrder(const PtNodeParams & ptNodeParams,const int index,int * const outCodePoints)151 AK_FORCE_INLINE void fetchMergedNodeCodePointsInReverseOrder(const PtNodeParams &ptNodeParams, 152 const int index, int *const outCodePoints) const { 153 const int nodeCodePointCount = ptNodeParams.getCodePointCount(); 154 const int *const nodeCodePoints = ptNodeParams.getCodePoints(); 155 for (int i = 0; i < nodeCodePointCount; ++i) { 156 outCodePoints[index + i] = nodeCodePoints[nodeCodePointCount - 1 - i]; 157 } 158 } 159 readNextSiblingNode(const PtNodeParams & ptNodeParams)160 AK_FORCE_INLINE void readNextSiblingNode(const PtNodeParams &ptNodeParams) { 161 mReadingState.mRemainingPtNodeCountInThisArray -= 1; 162 mReadingState.mPos = ptNodeParams.getSiblingNodePos(); 163 if (mReadingState.mRemainingPtNodeCountInThisArray <= 0) { 164 // All nodes in the current node array have been read. 165 followForwardLink(); 166 } 167 } 168 169 // Read the first child node of the current node. readChildNode(const PtNodeParams & ptNodeParams)170 AK_FORCE_INLINE void readChildNode(const PtNodeParams &ptNodeParams) { 171 if (ptNodeParams.hasChildren()) { 172 mReadingState.mTotalCodePointCountSinceInitialization += 173 ptNodeParams.getCodePointCount(); 174 mReadingState.mTotalPtNodeIndexInThisArrayChain = 0; 175 mReadingState.mPtNodeArrayIndexInThisArrayChain = 0; 176 mReadingState.mPos = ptNodeParams.getChildrenPos(); 177 mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; 178 // Read children node array. 179 nextPtNodeArray(); 180 } else { 181 mReadingState.mPos = NOT_A_DICT_POS; 182 } 183 } 184 185 // Read the parent node of the current node. readParentNode(const PtNodeParams & ptNodeParams)186 AK_FORCE_INLINE void readParentNode(const PtNodeParams &ptNodeParams) { 187 if (ptNodeParams.getParentPos() != NOT_A_DICT_POS) { 188 mReadingState.mTotalCodePointCountSinceInitialization += 189 ptNodeParams.getCodePointCount(); 190 mReadingState.mTotalPtNodeIndexInThisArrayChain = 1; 191 mReadingState.mPtNodeArrayIndexInThisArrayChain = 1; 192 mReadingState.mRemainingPtNodeCountInThisArray = 1; 193 mReadingState.mPos = ptNodeParams.getParentPos(); 194 mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; 195 mReadingState.mPosOfThisPtNodeArrayHead = NOT_A_DICT_POS; 196 } else { 197 mReadingState.mPos = NOT_A_DICT_POS; 198 } 199 } 200 getPosOfLastForwardLinkField()201 AK_FORCE_INLINE int getPosOfLastForwardLinkField() const { 202 return mReadingState.mPosOfLastForwardLinkField; 203 } 204 getPosOfLastPtNodeArrayHead()205 AK_FORCE_INLINE int getPosOfLastPtNodeArrayHead() const { 206 return mReadingState.mPosOfThisPtNodeArrayHead; 207 } 208 209 bool traverseAllPtNodesInPostorderDepthFirstManner(TraversingEventListener *const listener); 210 211 bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( 212 TraversingEventListener *const listener); 213 214 int getCodePointsAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints); 215 216 int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length, 217 const bool forceLowerCaseSearch); 218 219 private: 220 DISALLOW_COPY_AND_ASSIGN(DynamicPtReadingHelper); 221 222 // This class encapsulates the reading state of a position in the dictionary. It points at a 223 // specific PtNode in the dictionary. 224 class PtNodeReadingState { 225 public: 226 // Note that copy constructor and assignment operator are used for this class to use 227 // std::vector. PtNodeReadingState()228 PtNodeReadingState() : mPos(NOT_A_DICT_POS), mRemainingPtNodeCountInThisArray(0), 229 mTotalCodePointCountSinceInitialization(0), mTotalPtNodeIndexInThisArrayChain(0), 230 mPtNodeArrayIndexInThisArrayChain(0), mPosOfLastForwardLinkField(NOT_A_DICT_POS), 231 mPosOfThisPtNodeArrayHead(NOT_A_DICT_POS) {} 232 233 int mPos; 234 // Remaining node count in the current array. 235 int mRemainingPtNodeCountInThisArray; 236 size_t mTotalCodePointCountSinceInitialization; 237 // Counter of PtNodes used to avoid infinite loops caused by broken or malicious links. 238 int mTotalPtNodeIndexInThisArrayChain; 239 // Counter of PtNode arrays used to avoid infinite loops caused by cyclic links of empty 240 // PtNode arrays. 241 int mPtNodeArrayIndexInThisArrayChain; 242 int mPosOfLastForwardLinkField; 243 int mPosOfThisPtNodeArrayHead; 244 }; 245 246 static const int MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP; 247 static const int MAX_PT_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP; 248 static const size_t MAX_READING_STATE_STACK_SIZE; 249 250 // TODO: Introduce error code to track what caused the error. 251 bool mIsError; 252 PtNodeReadingState mReadingState; 253 const PtNodeReader *const mPtNodeReader; 254 const PtNodeArrayReader *const mPtNodeArrayReader; 255 std::vector<PtNodeReadingState> mReadingStateStack; 256 257 void nextPtNodeArray(); 258 259 void followForwardLink(); 260 pushReadingStateToStack()261 AK_FORCE_INLINE void pushReadingStateToStack() { 262 if (mReadingStateStack.size() > MAX_READING_STATE_STACK_SIZE) { 263 AKLOGI("Reading state stack overflow. Max size: %zd", MAX_READING_STATE_STACK_SIZE); 264 ASSERT(false); 265 mIsError = true; 266 mReadingState.mPos = NOT_A_DICT_POS; 267 } else { 268 mReadingStateStack.push_back(mReadingState); 269 } 270 } 271 popReadingStateFromStack()272 AK_FORCE_INLINE void popReadingStateFromStack() { 273 if (mReadingStateStack.empty()) { 274 mReadingState.mPos = NOT_A_DICT_POS; 275 } else { 276 mReadingState = mReadingStateStack.back(); 277 mReadingStateStack.pop_back(); 278 } 279 } 280 }; 281 } // namespace latinime 282 #endif /* LATINIME_DYNAMIC_PT_READING_HELPER_H */ 283