1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H 18 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H 19 20 #include <android-base/macros.h> 21 #include <sys/mman.h> 22 #include <vndk/hardware_buffer.h> 23 24 #include <algorithm> 25 #include <map> 26 #include <memory> 27 #include <mutex> 28 #include <set> 29 #include <tuple> 30 #include <unordered_map> 31 #include <utility> 32 #include <vector> 33 34 #include "CpuExecutor.h" 35 #include "HalInterfaces.h" 36 #include "NeuralNetworks.h" 37 #include "Utils.h" 38 39 namespace android { 40 namespace nn { 41 42 class CompilationBuilder; 43 class Device; 44 class ExecutionBurstController; 45 class ModelBuilder; 46 class PreparedModel; 47 48 // A utility template class to accumulate multiple objects and assign each 49 // a distinct index number, starting with 0. 50 // 51 // The user of this class is responsible for avoiding concurrent calls 52 // to this class from multiple threads. 53 template <typename ObjectType> 54 class ObjectTracker { 55 public: 56 // Adds the object, if it does not already exists. Returns its index. 57 // The objects should survive the tracker. add(const ObjectType * object)58 uint32_t add(const ObjectType* object) { 59 VLOG(MEMORY) << __func__ << "(" << SHOW_IF_DEBUG(object) << ")"; 60 // See if we already have this object. If so, return its index. 61 auto i = mKnown.find(object); 62 if (i != mKnown.end()) { 63 return i->second; 64 } 65 VLOG(MEMORY) << "It's new"; 66 // It's a new one. Save it an assign an index to it. 67 size_t next = mKnown.size(); 68 uint32_t idx = static_cast<uint32_t>(next); 69 mKnown[object] = idx; 70 mObjects.push_back(object); 71 return idx; 72 } 73 74 // Returns the number of objects contained. size()75 uint32_t size() const { return mObjects.size(); } 76 // Returns the ith object. 77 const ObjectType* operator[](size_t i) const { 78 CHECK(i < size()); 79 return mObjects[i]; 80 } 81 // Iteration begin()82 auto begin() { return mObjects.begin(); } end()83 auto end() { return mObjects.end(); } begin()84 auto begin() const { return mObjects.begin(); } end()85 auto end() const { return mObjects.end(); } getObjects()86 const std::vector<const ObjectType*>& getObjects() const { return mObjects; } 87 88 private: 89 // The vector of object pointers we are building. 90 std::vector<const ObjectType*> mObjects; 91 // A faster way to see if we already have an object than doing find(). 92 std::unordered_map<const ObjectType*, uint32_t> mKnown; 93 }; 94 95 using CompilationRole = std::tuple<const CompilationBuilder*, IOType, uint32_t>; 96 using StepRoleCallback = std::function<void(const PreparedModel*, IOType, uint32_t)>; 97 98 struct MemoryDescriptor { 99 std::vector<uint32_t> dimensions; 100 ObjectTracker<PreparedModel> preparedModels; 101 std::vector<hal::BufferRole> inputRoles, outputRoles; 102 }; 103 104 class MemoryValidatorBase { 105 DISALLOW_COPY_AND_ASSIGN(MemoryValidatorBase); 106 107 public: 108 MemoryValidatorBase() = default; 109 virtual ~MemoryValidatorBase() = default; 110 111 // Validate the memory usage and size information when passed in 112 // ANeuralNetworks{Model,Compilation}_set*FromMemory. 113 // 114 // This method only validates the arguments against the memory. It does not validate the 115 // correctness of the arguments themselves. E.g. it does not validate if the index is out of 116 // range. 117 // 118 // Usages: 119 // - ANeuralNetworksModel_setOperandValueFromMemory: 120 // validate(nullptr, IOType::INPUT, operandIndex, nullptr, offset, length) 121 // 122 // - ANeuralNetworksExecution_setInputFromMemory: 123 // validate(compilation, IOType::INPUT, inputIndex, type, offset, length) 124 // 125 // - ANeuralNetworksExecution_setOutputFromMemory: 126 // validate(compilation, IOType::OUTPUT, outputIndex, type, offset, length) 127 // 128 virtual bool validate(const CompilationBuilder* compilation, IOType ioType, uint32_t index, 129 const ANeuralNetworksOperandType* type, uint32_t offset, 130 uint32_t length) const = 0; 131 132 // Validate the memory dimensional information at the beginning of a computation. validateInputDimensions(const std::vector<uint32_t> &)133 virtual bool validateInputDimensions(const std::vector<uint32_t>&) const { return true; } 134 135 // The validation metadata for this memory. 136 struct Metadata { 137 // The byte size of the memory when it is transformed to a closely packed layout. 138 // Set to 0 if unknown (e.g. non-BLOB mode AHWB or device memory with dynamic shape). 139 uint32_t logicalSize; 140 141 // The dimensions of the memory. Set to empty if undefined. 142 std::vector<uint32_t> dimensions; 143 144 // The data type, scale, zero point, and extra parameters of the target operand. 145 // Other fields will be ignored, including dimensions, lifetime, location, etc. 146 // Set to std::nullopt if undefined. 147 std::optional<hal::Operand> operand; 148 }; 149 virtual Metadata getMetadata() const = 0; 150 151 // Try update the memory metadata with the provided metadata. Return false if incompatible. 152 virtual bool updateMetadata(const Metadata& metadata) = 0; 153 154 // Whether the memory is created with unknown dimensions or rank. createdWithUnknownShape()155 virtual bool createdWithUnknownShape() const { return false; } 156 setInitialized(bool)157 virtual void setInitialized(bool) {} isInitialized()158 virtual bool isInitialized() const { return true; } 159 }; 160 161 int copyIBufferToHidlMemory(const sp<hal::IBuffer>& src, const hal::hidl_memory& dst); 162 163 int copyHidlMemoryToIBuffer(const hal::hidl_memory& src, const sp<hal::IBuffer>& dst, 164 const std::vector<uint32_t>& dimensions); 165 166 // Represents a memory region. 167 class Memory { 168 // Disallow copy and assign to prevent slicing 169 DISALLOW_COPY_AND_ASSIGN(Memory); 170 171 public: 172 // Custom destructor to notify any ExecutionBurstControllers currently using 173 // this memory that it is being freed. 174 virtual ~Memory(); 175 176 hal::Request::MemoryPool getMemoryPool() const; getHidlMemory()177 const hal::hidl_memory& getHidlMemory() const { return kHidlMemory; } getIBuffer()178 const sp<hal::IBuffer>& getIBuffer() const { return kBuffer; } getSize()179 virtual uint32_t getSize() const { return getHidlMemory().size(); } 180 virtual std::optional<RunTimePoolInfo> getRunTimePoolInfo() const; 181 getValidator()182 MemoryValidatorBase& getValidator() const { 183 CHECK(mValidator != nullptr); 184 return *mValidator; 185 } 186 setValidator(std::unique_ptr<MemoryValidatorBase> validator)187 void setValidator(std::unique_ptr<MemoryValidatorBase> validator) { 188 mValidator = std::move(validator); 189 } 190 191 // Unique key representing this memory object. 192 intptr_t getKey() const; 193 194 // Marks a burst object as currently using this memory. When this 195 // memory object is destroyed, it will automatically free this memory from 196 // the bursts' memory cache. 197 void usedBy(const std::shared_ptr<ExecutionBurstController>& burst) const; 198 199 static int copy(const Memory& src, const Memory& dst); 200 201 protected: 202 Memory(hal::hidl_memory memory); 203 Memory(hal::hidl_memory memory, std::unique_ptr<MemoryValidatorBase> validator); 204 Memory(sp<hal::IBuffer> buffer, uint32_t token); 205 206 // The HIDL representation for this memory. We will use one of the following values 207 // when communicating with the drivers. 208 const hal::hidl_memory kHidlMemory; 209 const sp<hal::IBuffer> kBuffer; 210 const uint32_t kToken = 0; 211 212 std::unique_ptr<MemoryValidatorBase> mValidator; 213 214 private: 215 mutable std::mutex mMutex; 216 // mUsedBy is essentially a set of burst objects which use this Memory 217 // object. However, std::weak_ptr does not have comparison operations nor a 218 // std::hash implementation. This is because it is either a valid pointer 219 // (non-null) if the shared object is still alive, or it is null if the 220 // object has been freed. To circumvent this, mUsedBy is a map with the raw 221 // pointer as the key and the weak_ptr as the value. 222 mutable std::unordered_map<const ExecutionBurstController*, 223 std::weak_ptr<ExecutionBurstController>> 224 mUsedBy; 225 226 mutable std::optional<RunTimePoolInfo> mCachedRunTimePoolInfo; 227 mutable bool mHasCachedRunTimePoolInfo = false; 228 }; 229 230 class MemoryBuilder { 231 DISALLOW_COPY_AND_ASSIGN(MemoryBuilder); 232 233 public: 234 MemoryBuilder() = default; 235 236 int addRole(const CompilationBuilder& compilation, IOType ioType, uint32_t index, float freq); 237 int setDimensions(const std::vector<uint32_t>& dimensions); 238 239 int finish(); 240 241 std::pair<int, std::unique_ptr<Memory>> allocate() const; 242 243 private: 244 bool badState(const char* name) const; 245 246 // The memory descriptor that the MemoryBuilder is building. 247 MemoryDescriptor mDesc; 248 249 // The roles that have been specified via addRole. 250 // This is to check whether a new role has been seen before or not. 251 std::set<CompilationRole> mRoles; 252 253 // Keep track of the data type, scale, zero point, and extra parameters of the target operand. 254 // Other fields will be ignored, including dimensions, lifetime, location, etc. 255 // It is std::nullopt if no usage has been specified yet. 256 std::optional<hal::Operand> mOperand; 257 258 // Once the descriptor has been finished, we should not allow further modifications. 259 bool mFinished = false; 260 261 // The following fields are only valid when finished. 262 263 // The chosen device to allocate the memory. Set to nullptr if there are multiple devices. 264 const Device* mAllocator = nullptr; 265 266 // Whether BLOB mode AHWB is supported on all of the relevant devices of the roles. 267 bool mSupportsAhwb = false; 268 269 // If set to true, allocate() will fallback to Ashmem or AHardwareBuffer if the memory 270 // allocation fails on the chosen device, or if there is no device chosen. 271 bool mShouldFallback = true; 272 }; 273 274 class MemoryAshmem : public Memory { 275 public: 276 // Creates a memory object containing a new android shared memory ("ashmem") 277 // object of the size specified in bytes. Because this ashmem region can be 278 // shared with and accessed by one or more driver processes, MemoryAshmem 279 // has shared ownership over the ashmem region. 280 // 281 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 282 // On error, returns the appropriate NNAPI error code and nullptr. 283 static std::pair<int, std::unique_ptr<MemoryAshmem>> create(uint32_t size); 284 285 // Get a pointer to the ashmem region of memory. The returned pointer is 286 // valid for the lifetime of the MemoryAshmem object. This call always 287 // returns non-null because it was validated during MemoryAshmem::create. 288 uint8_t* getPointer() const; 289 getRunTimePoolInfo()290 std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override { 291 return RunTimePoolInfo::createFromExistingBuffer(getPointer(), kHidlMemory.size()); 292 } 293 294 // prefer using MemoryAshmem::create 295 MemoryAshmem(sp<hal::IMemory> mapped, hal::hidl_memory memory); 296 297 private: 298 const sp<hal::IMemory> kMappedMemory; 299 }; 300 301 class MemoryFd : public Memory { 302 public: 303 // Create a memory object based on input size, prot, and fd that can be sent 304 // across HIDL. This function duplicates the provided fd, and owns the 305 // duplicate. 306 // 307 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 308 // On error, returns the appropriate NNAPI error code and nullptr. 309 static std::pair<int, std::unique_ptr<MemoryFd>> create(size_t size, int prot, int fd, 310 size_t offset); 311 312 // prefer using MemoryFd::create 313 MemoryFd(hal::hidl_memory memory); 314 }; 315 316 class MemoryAHWB : public Memory { 317 public: 318 // Create a memory object to keep track of (but not take ownership of) the 319 // provided AHardwareBuffer handle. 320 // 321 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 322 // On error, returns the appropriate NNAPI error code and nullptr. 323 static std::pair<int, std::unique_ptr<MemoryAHWB>> create(const AHardwareBuffer& ahwb); 324 325 // prefer using MemoryAHWB::create MemoryAHWB(hal::hidl_memory memory,std::unique_ptr<MemoryValidatorBase> validator)326 MemoryAHWB(hal::hidl_memory memory, std::unique_ptr<MemoryValidatorBase> validator) 327 : Memory(std::move(memory), std::move(validator)) {} 328 }; 329 330 class MemoryRuntimeAHWB : public Memory { 331 public: 332 // Create a memory object containing a new BLOB-mode AHardwareBuffer memory 333 // object of the size specified in bytes. The created memory is managed and 334 // owned by the NNAPI runtime. 335 // 336 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 337 // On error, returns the appropriate NNAPI error code and nullptr. 338 static std::pair<int, std::unique_ptr<MemoryRuntimeAHWB>> create(uint32_t size); 339 340 // Get a pointer to the content of the memory. The returned pointer is 341 // valid for the lifetime of the MemoryRuntimeAHWB object. This call always 342 // returns non-null because it was validated during MemoryRuntimeAHWB::create. getPointer()343 uint8_t* getPointer() const { return mBuffer; } 344 getRunTimePoolInfo()345 std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override { 346 return RunTimePoolInfo::createFromExistingBuffer(getPointer(), kHidlMemory.size()); 347 } 348 349 // prefer using MemoryRuntimeAHWB::create 350 MemoryRuntimeAHWB(hal::hidl_memory memory, AHardwareBuffer* ahwb, uint8_t* buffer); 351 ~MemoryRuntimeAHWB(); 352 353 private: 354 AHardwareBuffer* const mAhwb; 355 uint8_t* const mBuffer; 356 }; 357 358 class MemoryFromDevice : public Memory { 359 public: 360 // Create a memory object to keep track of a driver-allocated device memory. 361 // The memory is recognized by the driver via a token. 362 // 363 // On success, returns ANEURALNETWORKS_NO_ERROR and a memory object. 364 // On error, returns the appropriate NNAPI error code and nullptr. 365 static std::pair<int, std::unique_ptr<MemoryFromDevice>> create(sp<hal::IBuffer> buffer, 366 uint32_t token); 367 368 // prefer using MemoryFromDevice::create 369 MemoryFromDevice(sp<hal::IBuffer> buffer, uint32_t token); 370 }; 371 372 using MemoryTracker = ObjectTracker<Memory>; 373 374 } // namespace nn 375 } // namespace android 376 377 #endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_MEMORY_H 378