1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_BUFFER_TRACKER_H 18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_BUFFER_TRACKER_H 19 20 #include <android-base/macros.h> 21 22 #include <map> 23 #include <memory> 24 #include <mutex> 25 #include <set> 26 #include <stack> 27 #include <utility> 28 #include <vector> 29 30 #include "CpuExecutor.h" 31 #include "HalInterfaces.h" 32 #include "Utils.h" 33 34 namespace android::nn { 35 36 // This class manages a CPU buffer allocated on heap and provides validation methods. 37 class ManagedBuffer { 38 public: 39 static std::shared_ptr<ManagedBuffer> create(uint32_t size, std::set<PreparedModelRole> roles, 40 const hal::Operand& operand); 41 42 // Prefer ManagedBuffer::create. 43 ManagedBuffer(std::unique_ptr<uint8_t[]> buffer, uint32_t size, 44 std::set<PreparedModelRole> roles, const hal::Operand& operand); 45 createRunTimePoolInfo()46 RunTimePoolInfo createRunTimePoolInfo() const { 47 return RunTimePoolInfo::createFromExistingBuffer(kBuffer.get(), kSize); 48 } 49 50 // "poolIndex" is the index of this buffer in the request.pools. 51 hal::ErrorStatus validateRequest(uint32_t poolIndex, const hal::Request& request, 52 const hal::IPreparedModel* preparedModel) const; 53 54 // "size" is the byte size of the hidl_memory provided to the copyFrom or copyTo method. 55 hal::ErrorStatus validateCopyFrom(const std::vector<uint32_t>& dimensions, uint32_t size) const; 56 hal::ErrorStatus validateCopyTo(uint32_t size) const; 57 58 bool updateDimensions(const std::vector<uint32_t>& dimensions); 59 void setInitialized(bool initialized); 60 61 private: 62 mutable std::mutex mMutex; 63 const std::unique_ptr<uint8_t[]> kBuffer; 64 const uint32_t kSize; 65 const std::set<PreparedModelRole> kRoles; 66 const hal::OperandType kOperandType; 67 const std::vector<uint32_t> kInitialDimensions; 68 std::vector<uint32_t> mUpdatedDimensions; 69 bool mInitialized = false; 70 }; 71 72 // Keep track of all ManagedBuffers and assign each with a unique token. 73 class BufferTracker : public std::enable_shared_from_this<BufferTracker> { 74 DISALLOW_COPY_AND_ASSIGN(BufferTracker); 75 76 public: 77 // A RAII class to help manage the lifetime of the token. 78 // It is only supposed to be constructed in BufferTracker::add. 79 class Token { 80 DISALLOW_COPY_AND_ASSIGN(Token); 81 82 public: Token(uint32_t token,std::shared_ptr<BufferTracker> tracker)83 Token(uint32_t token, std::shared_ptr<BufferTracker> tracker) 84 : kToken(token), kBufferTracker(std::move(tracker)) {} ~Token()85 ~Token() { kBufferTracker->free(kToken); } get()86 uint32_t get() const { return kToken; } 87 88 private: 89 const uint32_t kToken; 90 const std::shared_ptr<BufferTracker> kBufferTracker; 91 }; 92 93 // The factory of BufferTracker. This ensures that the BufferTracker is always managed by a 94 // shared_ptr. create()95 static std::shared_ptr<BufferTracker> create() { return std::make_shared<BufferTracker>(); } 96 97 // Prefer BufferTracker::create. BufferTracker()98 BufferTracker() : mTokenToBuffers(1) {} 99 100 std::unique_ptr<Token> add(std::shared_ptr<ManagedBuffer> buffer); 101 std::shared_ptr<ManagedBuffer> get(uint32_t token) const; 102 103 private: 104 void free(uint32_t token); 105 106 mutable std::mutex mMutex; 107 std::stack<uint32_t, std::vector<uint32_t>> mFreeTokens; 108 109 // Since the tokens are allocated in a non-sparse way, we use a vector to represent the mapping. 110 // The index of the vector is the token. When the token gets freed, the corresponding entry is 111 // set to nullptr. mTokenToBuffers[0] is always set to nullptr because 0 is an invalid token. 112 std::vector<std::shared_ptr<ManagedBuffer>> mTokenToBuffers; 113 }; 114 115 } // namespace android::nn 116 117 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_BUFFER_TRACKER_H 118