1 /*
2  * Copyright (C) 2014, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
18 #define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
19 
20 #include <cstdio>
21 #include <vector>
22 
23 #include "defines.h"
24 #include "dictionary/property/word_attributes.h"
25 #include "dictionary/structure/v4/content/language_model_dict_content_global_counters.h"
26 #include "dictionary/structure/v4/content/probability_entry.h"
27 #include "dictionary/structure/v4/content/terminal_position_lookup_table.h"
28 #include "dictionary/structure/v4/ver4_dict_constants.h"
29 #include "dictionary/utils/entry_counters.h"
30 #include "dictionary/utils/trie_map.h"
31 #include "utils/byte_array_view.h"
32 #include "utils/int_array_view.h"
33 
34 namespace latinime {
35 
36 class HeaderPolicy;
37 
38 /**
39  * Class representing language model.
40  *
41  * This class provides methods to get and store unigram/n-gram probability information and flags.
42  */
43 class LanguageModelDictContent {
44  public:
45     // Pair of word id and probability entry used for iteration.
46     class WordIdAndProbabilityEntry {
47      public:
WordIdAndProbabilityEntry(const int wordId,const ProbabilityEntry & probabilityEntry)48         WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry)
49                 : mWordId(wordId), mProbabilityEntry(probabilityEntry) {}
50 
getWordId()51         int getWordId() const { return mWordId; }
getProbabilityEntry()52         const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; }
53 
54      private:
55         DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry);
56         DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry);
57 
58         const int mWordId;
59         const ProbabilityEntry mProbabilityEntry;
60     };
61 
62     // Iterator.
63     class EntryIterator {
64      public:
EntryIterator(const TrieMap::TrieMapIterator & trieMapIterator,const bool hasHistoricalInfo)65         EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator,
66                 const bool hasHistoricalInfo)
67                 : mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {}
68 
69         const WordIdAndProbabilityEntry operator*() const {
70             const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator;
71             return WordIdAndProbabilityEntry(
72                     result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo));
73         }
74 
75         bool operator!=(const EntryIterator &other) const {
76             return mTrieMapIterator != other.mTrieMapIterator;
77         }
78 
79         const EntryIterator &operator++() {
80             ++mTrieMapIterator;
81             return *this;
82         }
83 
84      private:
85         DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator);
86         DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator);
87 
88         TrieMap::TrieMapIterator mTrieMapIterator;
89         const bool mHasHistoricalInfo;
90     };
91 
92     // Class represents range to use range base for loops.
93     class EntryRange {
94      public:
EntryRange(const TrieMap::TrieMapRange trieMapRange,const bool hasHistoricalInfo)95         EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo)
96                 : mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {}
97 
begin()98         EntryIterator begin() const {
99             return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo);
100         }
101 
end()102         EntryIterator end() const {
103             return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo);
104         }
105 
106      private:
107         DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange);
108         DISALLOW_ASSIGNMENT_OPERATOR(EntryRange);
109 
110         const TrieMap::TrieMapRange mTrieMapRange;
111         const bool mHasHistoricalInfo;
112     };
113 
114     class DumppedFullEntryInfo {
115      public:
DumppedFullEntryInfo(std::vector<int> & prevWordIds,const int targetWordId,const WordAttributes & wordAttributes,const ProbabilityEntry & probabilityEntry)116         DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
117                 const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
118                 : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
119                   mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
120 
getPrevWordIds()121         const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
getTargetWordId()122         int getTargetWordId() const { return mTargetWordId; }
getWordAttributes()123         const WordAttributes &getWordAttributes() const { return mWordAttributes; }
getProbabilityEntry()124         const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
125 
126      private:
127         DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
128 
129         const std::vector<int> mPrevWordIds;
130         const int mTargetWordId;
131         const WordAttributes mWordAttributes;
132         const ProbabilityEntry mProbabilityEntry;
133     };
134 
LanguageModelDictContent(const ReadWriteByteArrayView * const buffers,const bool hasHistoricalInfo)135     LanguageModelDictContent(const ReadWriteByteArrayView *const buffers,
136             const bool hasHistoricalInfo)
137             : mTrieMap(buffers[TRIE_MAP_BUFFER_INDEX]),
138               mGlobalCounters(buffers[GLOBAL_COUNTERS_BUFFER_INDEX]),
139               mHasHistoricalInfo(hasHistoricalInfo) {}
140 
LanguageModelDictContent(const bool hasHistoricalInfo)141     explicit LanguageModelDictContent(const bool hasHistoricalInfo)
142             : mTrieMap(), mGlobalCounters(), mHasHistoricalInfo(hasHistoricalInfo) {}
143 
isNearSizeLimit()144     bool isNearSizeLimit() const {
145         return mTrieMap.isNearSizeLimit() || mGlobalCounters.needsToHalveCounters();
146     }
147 
    // Defined outside this header.
    // NOTE(review): presumably writes this content to the given file and returns true on
    // success -- confirm against the definition.
    bool save(FILE *const file) const;
149 
    // Garbage-collection entry point; defined outside this header.
    // NOTE(review): presumably rebuilds this content from originalContent while translating
    // terminal ids through terminalIdMap, and returns false on failure -- confirm against the
    // definition.
    bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const LanguageModelDictContent *const originalContent);
152 
    // Returns the word attributes for wordId given the previous words prevWordIds; defined
    // outside this header.
    // NOTE(review): presumably mustMatchAllPrevWords disables falling back to shorter
    // contexts -- confirm against the definition.
    const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
            const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const;
155 
getProbabilityEntry(const int wordId)156     ProbabilityEntry getProbabilityEntry(const int wordId) const {
157         return getNgramProbabilityEntry(WordIdArrayView(), wordId);
158     }
159 
setProbabilityEntry(const int wordId,const ProbabilityEntry * const probabilityEntry)160     bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) {
161         mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount());
162         return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry);
163     }
164 
removeProbabilityEntry(const int wordId)165     bool removeProbabilityEntry(const int wordId) {
166         return removeNgramProbabilityEntry(WordIdArrayView(), wordId);
167     }
168 
    // Returns the probability entry of wordId in the context prevWordIds; defined outside
    // this header. getProbabilityEntry() calls this with an empty prevWordIds for unigrams.
    ProbabilityEntry getNgramProbabilityEntry(const WordIdArrayView prevWordIds,
            const int wordId) const;
171 
    // Stores probabilityEntry for wordId in the context prevWordIds; defined outside this
    // header. setProbabilityEntry() calls this with an empty prevWordIds for unigrams.
    // NOTE(review): presumably returns true on success -- confirm against the definition.
    bool setNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId,
            const ProbabilityEntry *const probabilityEntry);
174 
    // Removes the entry of wordId in the context prevWordIds; defined outside this header.
    // removeProbabilityEntry() calls this with an empty prevWordIds for unigrams.
    // NOTE(review): presumably returns true on success -- confirm against the definition.
    bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
176 
    // Returns an EntryRange for the context prevWordIds, suitable for range-based for loops;
    // defined outside this header.
    EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
178 
    // Returns DumppedFullEntryInfo records for wordId; defined outside this header.
    // NOTE(review): presumably covers every n-gram entry in which wordId takes part --
    // confirm against the definition.
    std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
            const HeaderPolicy *const headerPolicy, const int wordId) const;
181 
updateAllProbabilityEntriesForGC(const HeaderPolicy * const headerPolicy,MutableEntryCounters * const outEntryCounters)182     bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
183             MutableEntryCounters *const outEntryCounters) {
184         if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
185                 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(),
186                 outEntryCounters)) {
187             return false;
188         }
189         if (mGlobalCounters.needsToHalveCounters()) {
190             mGlobalCounters.halveCounters();
191         }
192         return true;
193     }
194 
    // Defined outside this header.
    // NOTE(review): the entry counts passed in (presumably currentEntryCounts) should be the
    // ones produced by the update pass (presumably updateAllProbabilityEntriesForGC()) --
    // confirm against the callers.
    bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts,
            const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
198 
    // Defined outside this header.
    // NOTE(review): presumably updates the unigram and n-gram entries of wordId for the
    // context prevWordIds when a word is input, adjusting entryCountersToUpdate, and returns
    // false on failure -- confirm against the definition.
    bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
            const bool isValid, const HistoricalInfo historicalInfo,
            const HeaderPolicy *const headerPolicy,
            MutableEntryCounters *const entryCountersToUpdate);
203 
204  private:
205     DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
206 
    // Record describing one entry: a priority, a count, a key and the previous word ids that
    // lead to it. "Turncate" in the name is a misspelling of "Truncate" that is kept because
    // the name is used by declarations in this class.
    // NOTE(review): presumably filled in by getEntryInfo() and ordered with Comparator when
    // truncating entries -- confirm against the definitions.
    class EntryInfoToTurncate {
     public:
        // Ordering functor over EntryInfoToTurncate; operator() is defined outside this
        // header.
        class Comparator {
         public:
            bool operator()(const EntryInfoToTurncate &left,
                    const EntryInfoToTurncate &right) const;
         private:
            DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
        };

        // Defined outside this header.
        // NOTE(review): presumably copies prevWordCount ids from prevWordIds into
        // mPrevWordIds -- confirm against the definition.
        EntryInfoToTurncate(const int priority, const int count, const int key,
                const int prevWordCount, const int *const prevWordIds);

        int mPriority;
        // TODO: Remove.
        int mCount;
        int mKey;
        int mPrevWordCount;
        // Fixed-size storage with room for MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1 word ids.
        int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];

     private:
        DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
    };
230 
    // Indices into the buffers array passed to the constructor: the buffer used to construct
    // the trie map and the buffer used to construct the global counters. The values are
    // defined outside this header.
    static const int TRIE_MAP_BUFFER_INDEX;
    static const int GLOBAL_COUNTERS_BUFFER_INDEX;

    // Trie map; its values are decoded into ProbabilityEntry objects by EntryIterator.
    TrieMap mTrieMap;
    // Global counters; the total count grows in setProbabilityEntry() and the counters are
    // halved in updateAllProbabilityEntriesForGC() when needed.
    LanguageModelDictContentGlobalCounters mGlobalCounters;
    // Passed to ProbabilityEntry::decode() when entries are read through EntryIterator.
    const bool mHasHistoricalInfo;
237 
    // Private helpers; all of them are defined outside this header.

    // NOTE(review): presumably the recursive step of runGC() for one trie map range --
    // confirm against the definition.
    bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex);
    // NOTE(review): presumably returns the bitmap entry index for prevWordIds, creating the
    // intermediate levels when they are missing -- confirm against the definition.
    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
    // NOTE(review): presumably the read-only counterpart of createAndGetBitmapEntryIndex() --
    // confirm against the definition.
    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
    // Called by updateAllProbabilityEntriesForGC() with the root bitmap entry index and a
    // prevWordCount of 0.
    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
            const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters,
            MutableEntryCounters *const outEntryCounters);
    // NOTE(review): presumably limits the entries of targetLevel to maxEntryCount and reports
    // the resulting count through outEntryCount -- confirm against the definition.
    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
            const int maxEntryCount, const int targetLevel, int *const outEntryCount);
    // NOTE(review): presumably collects EntryInfoToTurncate records for targetLevel into
    // outEntryInfo -- confirm against the definition.
    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
            std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
    // NOTE(review): presumably derives the entry to store from originalProbabilityEntry and
    // the given historical info -- confirm against the definition.
    const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
            const bool isValid, const HistoricalInfo historicalInfo,
            const HeaderPolicy *const headerPolicy) const;
    // NOTE(review): presumably the recursive step of exportAllNgramEntriesRelatedToWord() --
    // confirm against the definition.
    void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
            std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
256 };
257 } // namespace latinime
258 #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
259