1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Callbacks"
18
19 #include "Callbacks.h"
20
21 #include <android-base/logging.h>
22 #include <limits>
23 #include <utility>
24 #include <vector>
25
26 namespace android::nn {
27
28 using namespace hal;
29
30 constexpr Timing kNoTiming = {.timeOnDevice = std::numeric_limits<uint64_t>::max(),
31 .timeInDriver = std::numeric_limits<uint64_t>::max()};
32
33 // PreparedModelCallback methods begin here
34
notifyInternal(bool deadObject,ErrorStatus errorStatus,const sp<V1_0::IPreparedModel> & preparedModel)35 Return<void> PreparedModelCallback::notifyInternal(bool deadObject, ErrorStatus errorStatus,
36 const sp<V1_0::IPreparedModel>& preparedModel) {
37 {
38 std::lock_guard<std::mutex> hold(mMutex);
39
40 // quick-return if object has already been notified
41 if (mNotified) {
42 return Void();
43 }
44
45 // store results and mark as notified
46 mDeadObject = deadObject;
47 mErrorStatus = errorStatus;
48 mPreparedModel = preparedModel;
49 mNotified = true;
50 }
51
52 mCondition.notify_all();
53 return Void();
54 }
55
notify(V1_0::ErrorStatus errorStatus,const sp<V1_0::IPreparedModel> & preparedModel)56 Return<void> PreparedModelCallback::notify(V1_0::ErrorStatus errorStatus,
57 const sp<V1_0::IPreparedModel>& preparedModel) {
58 return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), preparedModel);
59 }
60
notify_1_2(V1_0::ErrorStatus errorStatus,const sp<V1_2::IPreparedModel> & preparedModel)61 Return<void> PreparedModelCallback::notify_1_2(V1_0::ErrorStatus errorStatus,
62 const sp<V1_2::IPreparedModel>& preparedModel) {
63 return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), preparedModel);
64 }
65
notify_1_3(ErrorStatus errorStatus,const sp<V1_3::IPreparedModel> & preparedModel)66 Return<void> PreparedModelCallback::notify_1_3(ErrorStatus errorStatus,
67 const sp<V1_3::IPreparedModel>& preparedModel) {
68 return notifyInternal(false, errorStatus, preparedModel);
69 }
70
notifyAsDeadObject()71 void PreparedModelCallback::notifyAsDeadObject() {
72 notifyInternal(true, ErrorStatus::GENERAL_FAILURE, nullptr);
73 }
74
wait() const75 void PreparedModelCallback::wait() const {
76 std::unique_lock<std::mutex> lock(mMutex);
77 mCondition.wait(lock, [this] { return mNotified; });
78 }
79
getStatus() const80 ErrorStatus PreparedModelCallback::getStatus() const {
81 wait();
82 return mErrorStatus;
83 }
84
getPreparedModel() const85 sp<V1_0::IPreparedModel> PreparedModelCallback::getPreparedModel() const {
86 wait();
87 return mPreparedModel;
88 }
89
isDeadObject() const90 bool PreparedModelCallback::isDeadObject() const {
91 wait();
92 return mDeadObject;
93 }
94
95 // ExecutionCallback methods begin here
96
notify(V1_0::ErrorStatus errorStatus)97 Return<void> ExecutionCallback::notify(V1_0::ErrorStatus errorStatus) {
98 return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), {}, kNoTiming);
99 }
100
notify_1_2(V1_0::ErrorStatus errorStatus,const hidl_vec<OutputShape> & outputShapes,const Timing & timing)101 Return<void> ExecutionCallback::notify_1_2(V1_0::ErrorStatus errorStatus,
102 const hidl_vec<OutputShape>& outputShapes,
103 const Timing& timing) {
104 return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), outputShapes, timing);
105 }
106
notify_1_3(V1_3::ErrorStatus errorStatus,const hidl_vec<OutputShape> & outputShapes,const Timing & timing)107 Return<void> ExecutionCallback::notify_1_3(V1_3::ErrorStatus errorStatus,
108 const hidl_vec<OutputShape>& outputShapes,
109 const Timing& timing) {
110 return notifyInternal(false, errorStatus, outputShapes, timing);
111 }
112
notifyAsDeadObject()113 void ExecutionCallback::notifyAsDeadObject() {
114 notifyInternal(true, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
115 }
116
wait() const117 void ExecutionCallback::wait() const {
118 std::unique_lock<std::mutex> lock(mMutex);
119 mCondition.wait(lock, [this] { return mNotified; });
120
121 /*
122 * Note that we cannot call std::thread::join from ExecutionCallback's
123 * destructor: ExecutionCallback is intended to be reference counted, and it
124 * is possible that the reference count drops to zero in the bound thread,
125 * causing the bound thread to call this destructor. If a thread tries to
126 * join itself, it throws an exception, producing a message like the
127 * following:
128 *
129 * terminating with uncaught exception of type std::__1::system_error:
130 * thread::join failed: Resource deadlock would occur
131 */
132 if (mThread.joinable()) {
133 mThread.join();
134 }
135 }
136
getStatus() const137 ErrorStatus ExecutionCallback::getStatus() const {
138 wait();
139 return mErrorStatus;
140 }
141
getOutputShapes() const142 const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() const {
143 wait();
144 return mOutputShapes;
145 }
146
getTiming() const147 Timing ExecutionCallback::getTiming() const {
148 wait();
149 return mTiming;
150 }
151
isDeadObject() const152 bool ExecutionCallback::isDeadObject() const {
153 wait();
154 return mDeadObject;
155 }
156
bindThread(std::thread asyncThread)157 bool ExecutionCallback::bindThread(std::thread asyncThread) {
158 std::lock_guard<std::mutex> lock(mMutex);
159
160 // Ensure ExecutionCallback object does not already have a thread bound
161 if (mThread.joinable()) {
162 LOG(ERROR) << "ExecutionCallback::bindThread -- a thread has already been bound to this "
163 "callback object";
164 return false;
165 }
166
167 // Ensure the new thread is valid
168 if (!asyncThread.joinable()) {
169 LOG(ERROR) << "ExecutionCallback::bindThread -- the new thread is not joinable";
170 return false;
171 }
172
173 mThread = std::move(asyncThread);
174 return true;
175 }
176
setOnFinish(const ExecutionFinish & finish)177 void ExecutionCallback::setOnFinish(const ExecutionFinish& finish) {
178 std::lock_guard<std::mutex> hold(mMutex);
179
180 // Ensure ExecutionCallback object does not already have a "finish" callback
181 if (mOnFinish != nullptr) {
182 LOG(ERROR) << "ExecutionCallback::setOnFinish -- object already has a \"finish\" callback";
183 return;
184 }
185
186 // Ensure new "finish" callback is valid
187 if (finish == nullptr) {
188 LOG(ERROR) << "ExecutionCallback::setOnFinish -- \"finish\" callback is invalid";
189 return;
190 }
191
192 // Essure ExecutionCallback object has not already been notified
193 if (mNotified) {
194 LOG(ERROR) << "ExecutionCallback::setOnFinish -- ExecutionCallback has already been "
195 "notified with results";
196 return;
197 }
198
199 mOnFinish = finish;
200 }
201
notifyInternal(bool deadObject,ErrorStatus errorStatus,std::vector<OutputShape> outputShapes,Timing timing)202 Return<void> ExecutionCallback::notifyInternal(bool deadObject, ErrorStatus errorStatus,
203 std::vector<OutputShape> outputShapes,
204 Timing timing) {
205 // check results
206 if (!deadObject) {
207 if (errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
208 // outputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
209 if (outputShapes.size() == 0) {
210 LOG(ERROR)
211 << "Notified with empty output shape vector when OUTPUT_INSUFFICIENT_SIZE";
212 errorStatus = ErrorStatus::GENERAL_FAILURE;
213 outputShapes = {};
214 timing = kNoTiming;
215 }
216 } else if (errorStatus != ErrorStatus::NONE) {
217 // outputShapes must be empty if errorStatus is neither NONE nor
218 // OUTPUT_INSUFFICIENT_SIZE.
219 if (outputShapes.size() != 0) {
220 LOG(ERROR) << "Notified with non-empty output shape vector when error status is "
221 "neither NONE nor OUTPUT_INSUFFICIENT_SIZE";
222 errorStatus = ErrorStatus::GENERAL_FAILURE;
223 outputShapes = {};
224 timing = kNoTiming;
225 }
226 }
227 }
228
229 // store results
230 {
231 std::lock_guard<std::mutex> hold(mMutex);
232
233 // quick-return if object has already been notified
234 if (mNotified) {
235 return Void();
236 }
237
238 mDeadObject = deadObject;
239 mErrorStatus = errorStatus;
240 mOutputShapes = std::move(outputShapes);
241 mTiming = timing;
242 mNotified = true;
243
244 if (mOnFinish != nullptr) {
245 ErrorStatus status = mOnFinish(mErrorStatus, mOutputShapes);
246 mOnFinish = nullptr;
247 if (status != ErrorStatus::NONE) {
248 mErrorStatus = status;
249 }
250 }
251 }
252 mCondition.notify_all();
253 return Void();
254 }
255
256 } // namespace android::nn
257