1 /* 2 * Copyright (C) 2014, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H 18 #define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H 19 20 #include <cstdio> 21 #include <vector> 22 23 #include "defines.h" 24 #include "dictionary/property/word_attributes.h" 25 #include "dictionary/structure/v4/content/language_model_dict_content_global_counters.h" 26 #include "dictionary/structure/v4/content/probability_entry.h" 27 #include "dictionary/structure/v4/content/terminal_position_lookup_table.h" 28 #include "dictionary/structure/v4/ver4_dict_constants.h" 29 #include "dictionary/utils/entry_counters.h" 30 #include "dictionary/utils/trie_map.h" 31 #include "utils/byte_array_view.h" 32 #include "utils/int_array_view.h" 33 34 namespace latinime { 35 36 class HeaderPolicy; 37 38 /** 39 * Class representing language model. 40 * 41 * This class provides methods to get and store unigram/n-gram probability information and flags. 42 */ 43 class LanguageModelDictContent { 44 public: 45 // Pair of word id and probability entry used for iteration. 46 class WordIdAndProbabilityEntry { 47 public: WordIdAndProbabilityEntry(const int wordId,const ProbabilityEntry & probabilityEntry)48 WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry) 49 : mWordId(wordId), mProbabilityEntry(probabilityEntry) {} 50 getWordId()51 int getWordId() const { return mWordId; } getProbabilityEntry()52 const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; } 53 54 private: 55 DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry); 56 DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry); 57 58 const int mWordId; 59 const ProbabilityEntry mProbabilityEntry; 60 }; 61 62 // Iterator. 63 class EntryIterator { 64 public: EntryIterator(const TrieMap::TrieMapIterator & trieMapIterator,const bool hasHistoricalInfo)65 EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator, 66 const bool hasHistoricalInfo) 67 : mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {} 68 69 const WordIdAndProbabilityEntry operator*() const { 70 const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator; 71 return WordIdAndProbabilityEntry( 72 result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo)); 73 } 74 75 bool operator!=(const EntryIterator &other) const { 76 return mTrieMapIterator != other.mTrieMapIterator; 77 } 78 79 const EntryIterator &operator++() { 80 ++mTrieMapIterator; 81 return *this; 82 } 83 84 private: 85 DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator); 86 DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator); 87 88 TrieMap::TrieMapIterator mTrieMapIterator; 89 const bool mHasHistoricalInfo; 90 }; 91 92 // Class represents range to use range base for loops. 93 class EntryRange { 94 public: EntryRange(const TrieMap::TrieMapRange trieMapRange,const bool hasHistoricalInfo)95 EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo) 96 : mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {} 97 begin()98 EntryIterator begin() const { 99 return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo); 100 } 101 end()102 EntryIterator end() const { 103 return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo); 104 } 105 106 private: 107 DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange); 108 DISALLOW_ASSIGNMENT_OPERATOR(EntryRange); 109 110 const TrieMap::TrieMapRange mTrieMapRange; 111 const bool mHasHistoricalInfo; 112 }; 113 114 class DumppedFullEntryInfo { 115 public: DumppedFullEntryInfo(std::vector<int> & prevWordIds,const int targetWordId,const WordAttributes & wordAttributes,const ProbabilityEntry & probabilityEntry)116 DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId, 117 const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry) 118 : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId), 119 mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {} 120 getPrevWordIds()121 const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); } getTargetWordId()122 int getTargetWordId() const { return mTargetWordId; } getWordAttributes()123 const WordAttributes &getWordAttributes() const { return mWordAttributes; } getProbabilityEntry()124 const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; } 125 126 private: 127 DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo); 128 129 const std::vector<int> mPrevWordIds; 130 const int mTargetWordId; 131 const WordAttributes mWordAttributes; 132 const ProbabilityEntry mProbabilityEntry; 133 }; 134 LanguageModelDictContent(const ReadWriteByteArrayView * const buffers,const bool hasHistoricalInfo)135 LanguageModelDictContent(const ReadWriteByteArrayView *const buffers, 136 const bool hasHistoricalInfo) 137 : mTrieMap(buffers[TRIE_MAP_BUFFER_INDEX]), 138 mGlobalCounters(buffers[GLOBAL_COUNTERS_BUFFER_INDEX]), 139 mHasHistoricalInfo(hasHistoricalInfo) {} 140 LanguageModelDictContent(const bool hasHistoricalInfo)141 explicit LanguageModelDictContent(const bool hasHistoricalInfo) 142 : mTrieMap(), mGlobalCounters(), mHasHistoricalInfo(hasHistoricalInfo) {} 143 isNearSizeLimit()144 bool isNearSizeLimit() const { 145 return mTrieMap.isNearSizeLimit() || mGlobalCounters.needsToHalveCounters(); 146 } 147 148 bool save(FILE *const file) const; 149 150 bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, 151 const LanguageModelDictContent *const originalContent); 152 153 const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, 154 const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const; 155 getProbabilityEntry(const int wordId)156 ProbabilityEntry getProbabilityEntry(const int wordId) const { 157 return getNgramProbabilityEntry(WordIdArrayView(), wordId); 158 } 159 setProbabilityEntry(const int wordId,const ProbabilityEntry * const probabilityEntry)160 bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) { 161 mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount()); 162 return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); 163 } 164 removeProbabilityEntry(const int wordId)165 bool removeProbabilityEntry(const int wordId) { 166 return removeNgramProbabilityEntry(WordIdArrayView(), wordId); 167 } 168 169 ProbabilityEntry getNgramProbabilityEntry(const WordIdArrayView prevWordIds, 170 const int wordId) const; 171 172 bool setNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId, 173 const ProbabilityEntry *const probabilityEntry); 174 175 bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId); 176 177 EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const; 178 179 std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord( 180 const HeaderPolicy *const headerPolicy, const int wordId) const; 181 updateAllProbabilityEntriesForGC(const HeaderPolicy * const headerPolicy,MutableEntryCounters * const outEntryCounters)182 bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, 183 MutableEntryCounters *const outEntryCounters) { 184 if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), 185 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(), 186 outEntryCounters)) { 187 return false; 188 } 189 if (mGlobalCounters.needsToHalveCounters()) { 190 mGlobalCounters.halveCounters(); 191 } 192 return true; 193 } 194 195 // entryCounts should be created by updateAllProbabilityEntries. 196 bool truncateEntries(const EntryCounts ¤tEntryCounts, const EntryCounts &maxEntryCounts, 197 const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters); 198 199 bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId, 200 const bool isValid, const HistoricalInfo historicalInfo, 201 const HeaderPolicy *const headerPolicy, 202 MutableEntryCounters *const entryCountersToUpdate); 203 204 private: 205 DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); 206 207 class EntryInfoToTurncate { 208 public: 209 class Comparator { 210 public: 211 bool operator()(const EntryInfoToTurncate &left, 212 const EntryInfoToTurncate &right) const; 213 private: 214 DISALLOW_ASSIGNMENT_OPERATOR(Comparator); 215 }; 216 217 EntryInfoToTurncate(const int priority, const int count, const int key, 218 const int prevWordCount, const int *const prevWordIds); 219 220 int mPriority; 221 // TODO: Remove. 222 int mCount; 223 int mKey; 224 int mPrevWordCount; 225 int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; 226 227 private: 228 DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); 229 }; 230 231 static const int TRIE_MAP_BUFFER_INDEX; 232 static const int GLOBAL_COUNTERS_BUFFER_INDEX; 233 234 TrieMap mTrieMap; 235 LanguageModelDictContentGlobalCounters mGlobalCounters; 236 const bool mHasHistoricalInfo; 237 238 bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, 239 const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex); 240 int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); 241 int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; 242 bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, 243 const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters, 244 MutableEntryCounters *const outEntryCounters); 245 bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, 246 const int maxEntryCount, const int targetLevel, int *const outEntryCount); 247 bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, 248 const int bitmapEntryIndex, std::vector<int> *const prevWordIds, 249 std::vector<EntryInfoToTurncate> *const outEntryInfo) const; 250 const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry, 251 const bool isValid, const HistoricalInfo historicalInfo, 252 const HeaderPolicy *const headerPolicy) const; 253 void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy, 254 const int bitmapEntryIndex, std::vector<int> *const prevWordIds, 255 std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const; 256 }; 257 } // namespace latinime 258 #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ 259