1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H
18 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H
19 
#include <android-base/macros.h>
#include <sys/mman.h>
#include <vndk/hardware_buffer.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "CpuExecutor.h"
#include "HalInterfaces.h"
#include "NeuralNetworks.h"
#include "Utils.h"
38 
39 namespace android {
40 namespace nn {
41 
// Forward declarations. In this header these classes are only used through
// pointers, references, or smart pointers, so their full definitions are not needed.
// NOTE(review): ModelBuilder is not referenced anywhere in this header; the
// declaration may be removable -- confirm no includer relies on it first.
class CompilationBuilder;
class Device;
class ExecutionBurstController;
class ModelBuilder;
class PreparedModel;
47 
48 // A utility template class to accumulate multiple objects and assign each
49 // a distinct index number, starting with 0.
50 //
51 // The user of this class is responsible for avoiding concurrent calls
52 // to this class from multiple threads.
53 template <typename ObjectType>
54 class ObjectTracker {
55    public:
56     // Adds the object, if it does not already exists.  Returns its index.
57     // The objects should survive the tracker.
add(const ObjectType * object)58     uint32_t add(const ObjectType* object) {
59         VLOG(MEMORY) << __func__ << "(" << SHOW_IF_DEBUG(object) << ")";
60         // See if we already have this object. If so, return its index.
61         auto i = mKnown.find(object);
62         if (i != mKnown.end()) {
63             return i->second;
64         }
65         VLOG(MEMORY) << "It's new";
66         // It's a new one.  Save it an assign an index to it.
67         size_t next = mKnown.size();
68         uint32_t idx = static_cast<uint32_t>(next);
69         mKnown[object] = idx;
70         mObjects.push_back(object);
71         return idx;
72     }
73 
74     // Returns the number of objects contained.
size()75     uint32_t size() const { return mObjects.size(); }
76     // Returns the ith object.
77     const ObjectType* operator[](size_t i) const {
78         CHECK(i < size());
79         return mObjects[i];
80     }
81     // Iteration
begin()82     auto begin() { return mObjects.begin(); }
end()83     auto end() { return mObjects.end(); }
begin()84     auto begin() const { return mObjects.begin(); }
end()85     auto end() const { return mObjects.end(); }
getObjects()86     const std::vector<const ObjectType*>& getObjects() const { return mObjects; }
87 
88    private:
89     // The vector of object pointers we are building.
90     std::vector<const ObjectType*> mObjects;
91     // A faster way to see if we already have an object than doing find().
92     std::unordered_map<const ObjectType*, uint32_t> mKnown;
93 };
94 
95 using CompilationRole = std::tuple<const CompilationBuilder*, IOType, uint32_t>;
96 using StepRoleCallback = std::function<void(const PreparedModel*, IOType, uint32_t)>;
97 
98 struct MemoryDescriptor {
99     std::vector<uint32_t> dimensions;
100     ObjectTracker<PreparedModel> preparedModels;
101     std::vector<hal::BufferRole> inputRoles, outputRoles;
102 };
103 
104 class MemoryValidatorBase {
105     DISALLOW_COPY_AND_ASSIGN(MemoryValidatorBase);
106 
107    public:
108     MemoryValidatorBase() = default;
109     virtual ~MemoryValidatorBase() = default;
110 
111     // Validate the memory usage and size information when passed in
112     // ANeuralNetworks{Model,Compilation}_set*FromMemory.
113     //
114     // This method only validates the arguments against the memory. It does not validate the
115     // correctness of the arguments themselves. E.g. it does not validate if the index is out of
116     // range.
117     //
118     // Usages:
119     //   - ANeuralNetworksModel_setOperandValueFromMemory:
120     //         validate(nullptr, IOType::INPUT, operandIndex, nullptr, offset, length)
121     //
122     //   - ANeuralNetworksExecution_setInputFromMemory:
123     //         validate(compilation, IOType::INPUT, inputIndex, type, offset, length)
124     //
125     //   - ANeuralNetworksExecution_setOutputFromMemory:
126     //         validate(compilation, IOType::OUTPUT, outputIndex, type, offset, length)
127     //
128     virtual bool validate(const CompilationBuilder* compilation, IOType ioType, uint32_t index,
129                           const ANeuralNetworksOperandType* type, uint32_t offset,
130                           uint32_t length) const = 0;
131 
132     // Validate the memory dimensional information at the beginning of a computation.
validateInputDimensions(const std::vector<uint32_t> &)133     virtual bool validateInputDimensions(const std::vector<uint32_t>&) const { return true; }
134 
135     // The validation metadata for this memory.
136     struct Metadata {
137         // The byte size of the memory when it is transformed to a closely packed layout.
138         // Set to 0 if unknown (e.g. non-BLOB mode AHWB or device memory with dynamic shape).
139         uint32_t logicalSize;
140 
141         // The dimensions of the memory. Set to empty if undefined.
142         std::vector<uint32_t> dimensions;
143 
144         // The data type, scale, zero point, and extra parameters of the target operand.
145         // Other fields will be ignored, including dimensions, lifetime, location, etc.
146         // Set to std::nullopt if undefined.
147         std::optional<hal::Operand> operand;
148     };
149     virtual Metadata getMetadata() const = 0;
150 
151     // Try update the memory metadata with the provided metadata. Return false if incompatible.
152     virtual bool updateMetadata(const Metadata& metadata) = 0;
153 
154     // Whether the memory is created with unknown dimensions or rank.
createdWithUnknownShape()155     virtual bool createdWithUnknownShape() const { return false; }
156 
setInitialized(bool)157     virtual void setInitialized(bool) {}
isInitialized()158     virtual bool isInitialized() const { return true; }
159 };
160 
161 int copyIBufferToHidlMemory(const sp<hal::IBuffer>& src, const hal::hidl_memory& dst);
162 
163 int copyHidlMemoryToIBuffer(const hal::hidl_memory& src, const sp<hal::IBuffer>& dst,
164                             const std::vector<uint32_t>& dimensions);
165 
// Represents a memory region.
//
// A Memory is presented to drivers either as a hal::hidl_memory (kHidlMemory)
// or as a driver-side IBuffer plus token (kBuffer / kToken). Which one is
// populated presumably depends on the protected constructor a derived class
// calls -- confirm in Memory.cpp. Concrete memories are created through the
// derived classes' static create() functions.
class Memory {
    // Disallow copy and assign to prevent slicing
    DISALLOW_COPY_AND_ASSIGN(Memory);

   public:
    // Custom destructor to notify any ExecutionBurstControllers currently using
    // this memory that it is being freed.
    virtual ~Memory();

    // Returns the request memory pool describing this memory to a driver.
    // Defined out of line.
    hal::Request::MemoryPool getMemoryPool() const;
    // Accessors for the two possible driver-facing representations.
    const hal::hidl_memory& getHidlMemory() const { return kHidlMemory; }
    const sp<hal::IBuffer>& getIBuffer() const { return kBuffer; }
    // Size reported by kHidlMemory; derived classes may override.
    virtual uint32_t getSize() const { return getHidlMemory().size(); }
    // Pool info used for CPU access to this memory. Defined out of line; derived
    // classes with an already-mapped pointer override it.
    virtual std::optional<RunTimePoolInfo> getRunTimePoolInfo() const;

    // Returns the attached validator. CHECK-fails if none has been set.
    MemoryValidatorBase& getValidator() const {
        CHECK(mValidator != nullptr);
        return *mValidator;
    }

    // Replaces the attached validator, taking ownership of it.
    void setValidator(std::unique_ptr<MemoryValidatorBase> validator) {
        mValidator = std::move(validator);
    }

    // Unique key representing this memory object.
    intptr_t getKey() const;

    // Marks a burst object as currently using this memory. When this
    // memory object is destroyed, it will automatically free this memory from
    // the bursts' memory cache.
    void usedBy(const std::shared_ptr<ExecutionBurstController>& burst) const;

    // Copies from src to dst. NOTE(review): the int result is presumably an
    // ANEURALNETWORKS_* result code -- confirm in Memory.cpp.
    static int copy(const Memory& src, const Memory& dst);

   protected:
    Memory(hal::hidl_memory memory);
    Memory(hal::hidl_memory memory, std::unique_ptr<MemoryValidatorBase> validator);
    Memory(sp<hal::IBuffer> buffer, uint32_t token);

    // The HIDL representation for this memory.  We will use one of the following values
    // when communicating with the drivers.
    const hal::hidl_memory kHidlMemory;
    const sp<hal::IBuffer> kBuffer;
    const uint32_t kToken = 0;

    // Validator returned by getValidator(); owned by this memory.
    std::unique_ptr<MemoryValidatorBase> mValidator;

   private:
    // NOTE(review): presumably guards mUsedBy (mutated by the const usedBy())
    // and possibly the cached pool info below -- confirm in Memory.cpp.
    mutable std::mutex mMutex;
    // mUsedBy is essentially a set of burst objects which use this Memory
    // object. However, std::weak_ptr does not have comparison operations nor a
    // std::hash implementation. This is because it is either a valid pointer
    // (non-null) if the shared object is still alive, or it is null if the
    // object has been freed. To circumvent this, mUsedBy is a map with the raw
    // pointer as the key and the weak_ptr as the value.
    mutable std::unordered_map<const ExecutionBurstController*,
                               std::weak_ptr<ExecutionBurstController>>
            mUsedBy;

    // Presumably a lazily computed getRunTimePoolInfo() result. The separate flag
    // distinguishes "not computed yet" from a cached std::nullopt -- confirm in Memory.cpp.
    mutable std::optional<RunTimePoolInfo> mCachedRunTimePoolInfo;
    mutable bool mHasCachedRunTimePoolInfo = false;
};
229 
// Builds a MemoryDescriptor from the roles and dimensions a client specifies,
// then allocates a Memory matching it.
// NOTE(review): the int results are presumably ANEURALNETWORKS_* result codes,
// as with the create() functions below -- confirm in Memory.cpp.
class MemoryBuilder {
    DISALLOW_COPY_AND_ASSIGN(MemoryBuilder);

   public:
    MemoryBuilder() = default;

    // Records that the memory will be used as input/output `index` of
    // `compilation`. `freq` is presumably a usage-frequency hint -- confirm.
    int addRole(const CompilationBuilder& compilation, IOType ioType, uint32_t index, float freq);
    // Sets the dimensions of the memory being described.
    int setDimensions(const std::vector<uint32_t>& dimensions);

    // Finalizes the descriptor; no further modifications are allowed afterwards.
    int finish();

    // Allocates a memory for the finished descriptor (see mAllocator and
    // mShouldFallback for how the allocator is chosen).
    std::pair<int, std::unique_ptr<Memory>> allocate() const;

   private:
    // NOTE(review): presumably checks whether `name` may be called in the
    // current state (e.g. after finish()) and returns true if not -- confirm.
    bool badState(const char* name) const;

    // The memory descriptor that the MemoryBuilder is building.
    MemoryDescriptor mDesc;

    // The roles that have been specified via addRole.
    // This is to check whether a new role has been seen before or not.
    std::set<CompilationRole> mRoles;

    // Keep track of the data type, scale, zero point, and extra parameters of the target operand.
    // Other fields will be ignored, including dimensions, lifetime, location, etc.
    // It is std::nullopt if no usage has been specified yet.
    std::optional<hal::Operand> mOperand;

    // Once the descriptor has been finished, we should not allow further modifications.
    bool mFinished = false;

    // The following fields are only valid when finished.

    // The chosen device to allocate the memory. Set to nullptr if there are multiple devices.
    const Device* mAllocator = nullptr;

    // Whether BLOB mode AHWB is supported on all of the relevant devices of the roles.
    bool mSupportsAhwb = false;

    // If set to true, allocate() will fall back to Ashmem or AHardwareBuffer if the memory
    // allocation fails on the chosen device, or if there is no device chosen.
    bool mShouldFallback = true;
};
273 
274 class MemoryAshmem : public Memory {
275    public:
276     // Creates a memory object containing a new android shared memory ("ashmem")
277     // object of the size specified in bytes. Because this ashmem region can be
278     // shared with and accessed by one or more driver processes, MemoryAshmem
279     // has shared ownership over the ashmem region.
280     //
281     // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
282     // On error, returns the appropriate NNAPI error code and nullptr.
283     static std::pair<int, std::unique_ptr<MemoryAshmem>> create(uint32_t size);
284 
285     // Get a pointer to the ashmem region of memory. The returned pointer is
286     // valid for the lifetime of the MemoryAshmem object. This call always
287     // returns non-null because it was validated during MemoryAshmem::create.
288     uint8_t* getPointer() const;
289 
getRunTimePoolInfo()290     std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override {
291         return RunTimePoolInfo::createFromExistingBuffer(getPointer(), kHidlMemory.size());
292     }
293 
294     // prefer using MemoryAshmem::create
295     MemoryAshmem(sp<hal::IMemory> mapped, hal::hidl_memory memory);
296 
297    private:
298     const sp<hal::IMemory> kMappedMemory;
299 };
300 
301 class MemoryFd : public Memory {
302    public:
303     // Create a memory object based on input size, prot, and fd that can be sent
304     // across HIDL. This function duplicates the provided fd, and owns the
305     // duplicate.
306     //
307     // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
308     // On error, returns the appropriate NNAPI error code and nullptr.
309     static std::pair<int, std::unique_ptr<MemoryFd>> create(size_t size, int prot, int fd,
310                                                             size_t offset);
311 
312     // prefer using MemoryFd::create
313     MemoryFd(hal::hidl_memory memory);
314 };
315 
316 class MemoryAHWB : public Memory {
317    public:
318     // Create a memory object to keep track of (but not take ownership of) the
319     // provided AHardwareBuffer handle.
320     //
321     // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
322     // On error, returns the appropriate NNAPI error code and nullptr.
323     static std::pair<int, std::unique_ptr<MemoryAHWB>> create(const AHardwareBuffer& ahwb);
324 
325     // prefer using MemoryAHWB::create
MemoryAHWB(hal::hidl_memory memory,std::unique_ptr<MemoryValidatorBase> validator)326     MemoryAHWB(hal::hidl_memory memory, std::unique_ptr<MemoryValidatorBase> validator)
327         : Memory(std::move(memory), std::move(validator)) {}
328 };
329 
330 class MemoryRuntimeAHWB : public Memory {
331    public:
332     // Create a memory object containing a new BLOB-mode AHardwareBuffer memory
333     // object of the size specified in bytes. The created memory is managed and
334     // owned by the NNAPI runtime.
335     //
336     // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
337     // On error, returns the appropriate NNAPI error code and nullptr.
338     static std::pair<int, std::unique_ptr<MemoryRuntimeAHWB>> create(uint32_t size);
339 
340     // Get a pointer to the content of the memory. The returned pointer is
341     // valid for the lifetime of the MemoryRuntimeAHWB object. This call always
342     // returns non-null because it was validated during MemoryRuntimeAHWB::create.
getPointer()343     uint8_t* getPointer() const { return mBuffer; }
344 
getRunTimePoolInfo()345     std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override {
346         return RunTimePoolInfo::createFromExistingBuffer(getPointer(), kHidlMemory.size());
347     }
348 
349     // prefer using MemoryRuntimeAHWB::create
350     MemoryRuntimeAHWB(hal::hidl_memory memory, AHardwareBuffer* ahwb, uint8_t* buffer);
351     ~MemoryRuntimeAHWB();
352 
353    private:
354     AHardwareBuffer* const mAhwb;
355     uint8_t* const mBuffer;
356 };
357 
// Memory allocated by a driver, represented to the runtime by an IBuffer and a
// token rather than by a hidl_memory.
class MemoryFromDevice : public Memory {
   public:
    // Create a memory object to keep track of a driver-allocated device memory.
    // The memory is recognized by the driver via a token.
    //
    // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object.
    // On error, returns the appropriate NNAPI error code and nullptr.
    static std::pair<int, std::unique_ptr<MemoryFromDevice>> create(sp<hal::IBuffer> buffer,
                                                                    uint32_t token);

    // prefer using MemoryFromDevice::create
    MemoryFromDevice(sp<hal::IBuffer> buffer, uint32_t token);
};
371 
// An ObjectTracker over Memory objects: assigns each distinct Memory* a dense
// index starting at 0. The tracked memories must outlive the tracker.
using MemoryTracker = ObjectTracker<Memory>;
373 
374 }  // namespace nn
375 }  // namespace android
376 
377 #endif  // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H
378