1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_BUFFER_TRACKER_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_BUFFER_TRACKER_H
19 
20 #include <android-base/macros.h>
21 
22 #include <map>
23 #include <memory>
24 #include <mutex>
25 #include <set>
26 #include <stack>
27 #include <utility>
28 #include <vector>
29 
30 #include "CpuExecutor.h"
31 #include "HalInterfaces.h"
32 #include "Utils.h"
33 
34 namespace android::nn {
35 
36 // This class manages a CPU buffer allocated on heap and provides validation methods.
37 class ManagedBuffer {
38    public:
39     static std::shared_ptr<ManagedBuffer> create(uint32_t size, std::set<PreparedModelRole> roles,
40                                                  const hal::Operand& operand);
41 
42     // Prefer ManagedBuffer::create.
43     ManagedBuffer(std::unique_ptr<uint8_t[]> buffer, uint32_t size,
44                   std::set<PreparedModelRole> roles, const hal::Operand& operand);
45 
createRunTimePoolInfo()46     RunTimePoolInfo createRunTimePoolInfo() const {
47         return RunTimePoolInfo::createFromExistingBuffer(kBuffer.get(), kSize);
48     }
49 
50     // "poolIndex" is the index of this buffer in the request.pools.
51     hal::ErrorStatus validateRequest(uint32_t poolIndex, const hal::Request& request,
52                                      const hal::IPreparedModel* preparedModel) const;
53 
54     // "size" is the byte size of the hidl_memory provided to the copyFrom or copyTo method.
55     hal::ErrorStatus validateCopyFrom(const std::vector<uint32_t>& dimensions, uint32_t size) const;
56     hal::ErrorStatus validateCopyTo(uint32_t size) const;
57 
58     bool updateDimensions(const std::vector<uint32_t>& dimensions);
59     void setInitialized(bool initialized);
60 
61    private:
62     mutable std::mutex mMutex;
63     const std::unique_ptr<uint8_t[]> kBuffer;
64     const uint32_t kSize;
65     const std::set<PreparedModelRole> kRoles;
66     const hal::OperandType kOperandType;
67     const std::vector<uint32_t> kInitialDimensions;
68     std::vector<uint32_t> mUpdatedDimensions;
69     bool mInitialized = false;
70 };
71 
72 // Keep track of all ManagedBuffers and assign each with a unique token.
73 class BufferTracker : public std::enable_shared_from_this<BufferTracker> {
74     DISALLOW_COPY_AND_ASSIGN(BufferTracker);
75 
76    public:
77     // A RAII class to help manage the lifetime of the token.
78     // It is only supposed to be constructed in BufferTracker::add.
79     class Token {
80         DISALLOW_COPY_AND_ASSIGN(Token);
81 
82        public:
Token(uint32_t token,std::shared_ptr<BufferTracker> tracker)83         Token(uint32_t token, std::shared_ptr<BufferTracker> tracker)
84             : kToken(token), kBufferTracker(std::move(tracker)) {}
~Token()85         ~Token() { kBufferTracker->free(kToken); }
get()86         uint32_t get() const { return kToken; }
87 
88        private:
89         const uint32_t kToken;
90         const std::shared_ptr<BufferTracker> kBufferTracker;
91     };
92 
93     // The factory of BufferTracker. This ensures that the BufferTracker is always managed by a
94     // shared_ptr.
create()95     static std::shared_ptr<BufferTracker> create() { return std::make_shared<BufferTracker>(); }
96 
97     // Prefer BufferTracker::create.
BufferTracker()98     BufferTracker() : mTokenToBuffers(1) {}
99 
100     std::unique_ptr<Token> add(std::shared_ptr<ManagedBuffer> buffer);
101     std::shared_ptr<ManagedBuffer> get(uint32_t token) const;
102 
103    private:
104     void free(uint32_t token);
105 
106     mutable std::mutex mMutex;
107     std::stack<uint32_t, std::vector<uint32_t>> mFreeTokens;
108 
109     // Since the tokens are allocated in a non-sparse way, we use a vector to represent the mapping.
110     // The index of the vector is the token. When the token gets freed, the corresponding entry is
111     // set to nullptr. mTokenToBuffers[0] is always set to nullptr because 0 is an invalid token.
112     std::vector<std::shared_ptr<ManagedBuffer>> mTokenToBuffers;
113 };
114 
115 }  // namespace android::nn
116 
117 #endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_BUFFER_TRACKER_H
118