1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Callbacks"
18 
19 #include "Callbacks.h"
20 
21 #include <android-base/logging.h>
22 #include <limits>
23 #include <utility>
24 #include <vector>
25 
26 namespace android::nn {
27 
28 using namespace hal;
29 
30 constexpr Timing kNoTiming = {.timeOnDevice = std::numeric_limits<uint64_t>::max(),
31                               .timeInDriver = std::numeric_limits<uint64_t>::max()};
32 
33 // PreparedModelCallback methods begin here
34 
notifyInternal(bool deadObject,ErrorStatus errorStatus,const sp<V1_0::IPreparedModel> & preparedModel)35 Return<void> PreparedModelCallback::notifyInternal(bool deadObject, ErrorStatus errorStatus,
36                                                    const sp<V1_0::IPreparedModel>& preparedModel) {
37     {
38         std::lock_guard<std::mutex> hold(mMutex);
39 
40         // quick-return if object has already been notified
41         if (mNotified) {
42             return Void();
43         }
44 
45         // store results and mark as notified
46         mDeadObject = deadObject;
47         mErrorStatus = errorStatus;
48         mPreparedModel = preparedModel;
49         mNotified = true;
50     }
51 
52     mCondition.notify_all();
53     return Void();
54 }
55 
notify(V1_0::ErrorStatus errorStatus,const sp<V1_0::IPreparedModel> & preparedModel)56 Return<void> PreparedModelCallback::notify(V1_0::ErrorStatus errorStatus,
57                                            const sp<V1_0::IPreparedModel>& preparedModel) {
58     return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), preparedModel);
59 }
60 
notify_1_2(V1_0::ErrorStatus errorStatus,const sp<V1_2::IPreparedModel> & preparedModel)61 Return<void> PreparedModelCallback::notify_1_2(V1_0::ErrorStatus errorStatus,
62                                                const sp<V1_2::IPreparedModel>& preparedModel) {
63     return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), preparedModel);
64 }
65 
notify_1_3(ErrorStatus errorStatus,const sp<V1_3::IPreparedModel> & preparedModel)66 Return<void> PreparedModelCallback::notify_1_3(ErrorStatus errorStatus,
67                                                const sp<V1_3::IPreparedModel>& preparedModel) {
68     return notifyInternal(false, errorStatus, preparedModel);
69 }
70 
notifyAsDeadObject()71 void PreparedModelCallback::notifyAsDeadObject() {
72     notifyInternal(true, ErrorStatus::GENERAL_FAILURE, nullptr);
73 }
74 
wait() const75 void PreparedModelCallback::wait() const {
76     std::unique_lock<std::mutex> lock(mMutex);
77     mCondition.wait(lock, [this] { return mNotified; });
78 }
79 
getStatus() const80 ErrorStatus PreparedModelCallback::getStatus() const {
81     wait();
82     return mErrorStatus;
83 }
84 
getPreparedModel() const85 sp<V1_0::IPreparedModel> PreparedModelCallback::getPreparedModel() const {
86     wait();
87     return mPreparedModel;
88 }
89 
isDeadObject() const90 bool PreparedModelCallback::isDeadObject() const {
91     wait();
92     return mDeadObject;
93 }
94 
95 // ExecutionCallback methods begin here
96 
notify(V1_0::ErrorStatus errorStatus)97 Return<void> ExecutionCallback::notify(V1_0::ErrorStatus errorStatus) {
98     return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), {}, kNoTiming);
99 }
100 
notify_1_2(V1_0::ErrorStatus errorStatus,const hidl_vec<OutputShape> & outputShapes,const Timing & timing)101 Return<void> ExecutionCallback::notify_1_2(V1_0::ErrorStatus errorStatus,
102                                            const hidl_vec<OutputShape>& outputShapes,
103                                            const Timing& timing) {
104     return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), outputShapes, timing);
105 }
106 
notify_1_3(V1_3::ErrorStatus errorStatus,const hidl_vec<OutputShape> & outputShapes,const Timing & timing)107 Return<void> ExecutionCallback::notify_1_3(V1_3::ErrorStatus errorStatus,
108                                            const hidl_vec<OutputShape>& outputShapes,
109                                            const Timing& timing) {
110     return notifyInternal(false, errorStatus, outputShapes, timing);
111 }
112 
notifyAsDeadObject()113 void ExecutionCallback::notifyAsDeadObject() {
114     notifyInternal(true, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
115 }
116 
wait() const117 void ExecutionCallback::wait() const {
118     std::unique_lock<std::mutex> lock(mMutex);
119     mCondition.wait(lock, [this] { return mNotified; });
120 
121     /*
122      * Note that we cannot call std::thread::join from ExecutionCallback's
123      * destructor: ExecutionCallback is intended to be reference counted, and it
124      * is possible that the reference count drops to zero in the bound thread,
125      * causing the bound thread to call this destructor. If a thread tries to
126      * join itself, it throws an exception, producing a message like the
127      * following:
128      *
129      *     terminating with uncaught exception of type std::__1::system_error:
130      *     thread::join failed: Resource deadlock would occur
131      */
132     if (mThread.joinable()) {
133         mThread.join();
134     }
135 }
136 
getStatus() const137 ErrorStatus ExecutionCallback::getStatus() const {
138     wait();
139     return mErrorStatus;
140 }
141 
getOutputShapes() const142 const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() const {
143     wait();
144     return mOutputShapes;
145 }
146 
getTiming() const147 Timing ExecutionCallback::getTiming() const {
148     wait();
149     return mTiming;
150 }
151 
isDeadObject() const152 bool ExecutionCallback::isDeadObject() const {
153     wait();
154     return mDeadObject;
155 }
156 
bindThread(std::thread asyncThread)157 bool ExecutionCallback::bindThread(std::thread asyncThread) {
158     std::lock_guard<std::mutex> lock(mMutex);
159 
160     // Ensure ExecutionCallback object does not already have a thread bound
161     if (mThread.joinable()) {
162         LOG(ERROR) << "ExecutionCallback::bindThread -- a thread has already been bound to this "
163                       "callback object";
164         return false;
165     }
166 
167     // Ensure the new thread is valid
168     if (!asyncThread.joinable()) {
169         LOG(ERROR) << "ExecutionCallback::bindThread -- the new thread is not joinable";
170         return false;
171     }
172 
173     mThread = std::move(asyncThread);
174     return true;
175 }
176 
setOnFinish(const ExecutionFinish & finish)177 void ExecutionCallback::setOnFinish(const ExecutionFinish& finish) {
178     std::lock_guard<std::mutex> hold(mMutex);
179 
180     // Ensure ExecutionCallback object does not already have a "finish" callback
181     if (mOnFinish != nullptr) {
182         LOG(ERROR) << "ExecutionCallback::setOnFinish -- object already has a \"finish\" callback";
183         return;
184     }
185 
186     // Ensure new "finish" callback is valid
187     if (finish == nullptr) {
188         LOG(ERROR) << "ExecutionCallback::setOnFinish -- \"finish\" callback is invalid";
189         return;
190     }
191 
192     // Essure ExecutionCallback object has not already been notified
193     if (mNotified) {
194         LOG(ERROR) << "ExecutionCallback::setOnFinish -- ExecutionCallback has already been "
195                       "notified with results";
196         return;
197     }
198 
199     mOnFinish = finish;
200 }
201 
notifyInternal(bool deadObject,ErrorStatus errorStatus,std::vector<OutputShape> outputShapes,Timing timing)202 Return<void> ExecutionCallback::notifyInternal(bool deadObject, ErrorStatus errorStatus,
203                                                std::vector<OutputShape> outputShapes,
204                                                Timing timing) {
205     // check results
206     if (!deadObject) {
207         if (errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
208             // outputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
209             if (outputShapes.size() == 0) {
210                 LOG(ERROR)
211                         << "Notified with empty output shape vector when OUTPUT_INSUFFICIENT_SIZE";
212                 errorStatus = ErrorStatus::GENERAL_FAILURE;
213                 outputShapes = {};
214                 timing = kNoTiming;
215             }
216         } else if (errorStatus != ErrorStatus::NONE) {
217             // outputShapes must be empty if errorStatus is neither NONE nor
218             // OUTPUT_INSUFFICIENT_SIZE.
219             if (outputShapes.size() != 0) {
220                 LOG(ERROR) << "Notified with non-empty output shape vector when error status is "
221                               "neither NONE nor OUTPUT_INSUFFICIENT_SIZE";
222                 errorStatus = ErrorStatus::GENERAL_FAILURE;
223                 outputShapes = {};
224                 timing = kNoTiming;
225             }
226         }
227     }
228 
229     // store results
230     {
231         std::lock_guard<std::mutex> hold(mMutex);
232 
233         // quick-return if object has already been notified
234         if (mNotified) {
235             return Void();
236         }
237 
238         mDeadObject = deadObject;
239         mErrorStatus = errorStatus;
240         mOutputShapes = std::move(outputShapes);
241         mTiming = timing;
242         mNotified = true;
243 
244         if (mOnFinish != nullptr) {
245             ErrorStatus status = mOnFinish(mErrorStatus, mOutputShapes);
246             mOnFinish = nullptr;
247             if (status != ErrorStatus::NONE) {
248                 mErrorStatus = status;
249             }
250         }
251     }
252     mCondition.notify_all();
253     return Void();
254 }
255 
256 }  // namespace android::nn
257