1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "VersionedInterfaces"
18 
19 #include "VersionedInterfaces.h"
20 
21 #include <fcntl.h>
22 
23 #include <android-base/logging.h>
24 #include <android-base/properties.h>
25 #include <android-base/scopeguard.h>
26 #include <android-base/thread_annotations.h>
27 #include <cutils/native_handle.h>
28 
29 #include <algorithm>
30 #include <chrono>
31 #include <functional>
32 #include <memory>
33 #include <string>
34 #include <tuple>
35 #include <type_traits>
36 #include <utility>
37 #include <vector>
38 
39 #include "Callbacks.h"
40 #include "ExecutionBurstController.h"
41 #include "MetaModel.h"
42 #include "Tracing.h"
43 #include "Utils.h"
44 
45 /*
46  * Some notes about HIDL interface objects and lifetimes across processes:
47  *
48  * All HIDL interface objects inherit from IBase, which itself inherits from
49  * ::android::RefBase. As such, all HIDL interface objects are reference counted
50  * and must be owned through ::android::sp (or referenced through ::android::wp).
51  * Allocating RefBase objects on the stack will log errors and may result in
52  * crashes, and deleting a RefBase object through another means (e.g., "delete",
53  * "free", or RAII-cleanup through std::unique_ptr or some equivalent) will
54  * result in double-free and/or use-after-free undefined behavior.
55  *
56  * HIDL/Binder manages the reference count of HIDL interface objects
57  * automatically across processes. If a process that references (but did not
58  * create) the HIDL interface object dies, HIDL/Binder ensures any reference
59  * count it held is properly released. (Caveat: it might be possible that
60  * HIDL/Binder behave strangely with ::android::wp references.)
61  *
62  * If the process which created the HIDL interface object dies, any call on this
63  * object from another process will result in a HIDL transport error with the
64  * code DEAD_OBJECT.
65  */
66 
67 /*
68  * Some notes about asynchronous calls across HIDL:
69  *
70  * For synchronous calls across HIDL, if an error occurs after the function was
71  * called but before it returns, HIDL will return a transport error. For
72  * example, if the message cannot be delivered to the server process or if the
73  * server process dies before returning a result, HIDL will return from the
74  * function with the appropriate transport error in the Return<> object which
75  * can be queried with Return<>::isOk(), Return<>::isDeadObject(),
76  * Return<>::description(), etc.
77  *
78  * However, HIDL offers no such error management in the case of asynchronous
79  * calls. By default, if the client launches an asynchronous task and the server
80  * fails to return a result through the callback, the client will be left
81  * waiting indefinitely for a result it will never receive.
82  *
83  * In the NNAPI, IDevice::prepareModel* and IPreparedModel::execute* (but not
84  * IPreparedModel::executeSynchronously*) are asynchronous calls across HIDL.
85  * Specifically, these asynchronous functions are called with a HIDL interface
86  * callback object (IPrepareModelCallback for IDevice::prepareModel* and
87  * IExecutionCallback for IPreparedModel::execute*) and are expected to quickly
88  * return, and the results are returned at a later time through these callback
89  * objects.
90  *
91  * To protect against the case when the server dies after the asynchronous task
92  * was called successfully but before the results could be returned, HIDL
93  * provides an object called a "hidl_death_recipient", which can be used to
94  * detect when an interface object (and more generally, the server process) has
95  * died. VersionedInterfaces uses hidl_death_recipients to detect when the
96  * driver process has died, and VersionedInterfaces will unblock any thread
97  * waiting on the results of a callback object that may otherwise not be
98  * signaled.
99  */
100 
101 namespace android {
102 namespace nn {
103 
104 // anonymous namespace
105 namespace {
106 
107 using namespace hal;
108 
109 const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
110 
sendFailureMessage(IPreparedModelCallback * cb)111 void sendFailureMessage(IPreparedModelCallback* cb) {
112     CHECK(cb != nullptr);
113     cb->notify_1_3(ErrorStatus::GENERAL_FAILURE, nullptr);
114 }
115 
116 // This class is thread safe
117 template <typename Callback>
118 class DeathHandler : public hidl_death_recipient {
119    public:
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)120     void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
121         LOG(ERROR) << "DeathHandler::serviceDied -- service unexpectedly died!";
122         std::lock_guard<std::mutex> hold(mMutex);
123         std::for_each(mCallbacks.begin(), mCallbacks.end(),
124                       [](const auto& cb) { cb->notifyAsDeadObject(); });
125     }
126 
protectCallback(const sp<Callback> & callback)127     [[nodiscard]] base::ScopeGuard<std::function<void()>> protectCallback(
128             const sp<Callback>& callback) {
129         registerCallback(callback);
130         return ::android::base::make_scope_guard(
131                 [this, callback] { unregisterCallback(callback); });
132     }
133 
134    private:
registerCallback(const sp<Callback> & callback)135     void registerCallback(const sp<Callback>& callback) {
136         std::lock_guard<std::mutex> hold(mMutex);
137         mCallbacks.push_back(callback);
138     }
139 
unregisterCallback(const sp<Callback> & callback)140     void unregisterCallback(const sp<Callback>& callback) {
141         std::lock_guard<std::mutex> hold(mMutex);
142         mCallbacks.erase(std::remove(mCallbacks.begin(), mCallbacks.end(), callback),
143                          mCallbacks.end());
144     }
145 
146     std::mutex mMutex;
147     std::vector<sp<Callback>> mCallbacks GUARDED_BY(mMutex);
148 };
149 
150 }  // anonymous namespace
151 
152 class IDeviceDeathHandler : public DeathHandler<PreparedModelCallback> {};
153 class IPreparedModelDeathHandler : public DeathHandler<ExecutionCallback> {};
154 
makeVersionedIPreparedModel(sp<V1_0::IPreparedModel> preparedModel)155 static std::pair<int, std::shared_ptr<VersionedIPreparedModel>> makeVersionedIPreparedModel(
156         sp<V1_0::IPreparedModel> preparedModel) {
157     CHECK(preparedModel != nullptr)
158             << "makeVersionedIPreparedModel passed invalid preparedModel object.";
159 
160     // create death handler object
161     sp<IPreparedModelDeathHandler> deathHandler = new IPreparedModelDeathHandler();
162 
163     // linkToDeath registers a callback that will be invoked on service death to
164     // proactively handle service crashes. If the linkToDeath call fails,
165     // asynchronous calls are susceptible to hangs if the service crashes before
166     // providing the response.
167     const Return<bool> ret = preparedModel->linkToDeath(deathHandler, 0);
168     if (ret.isDeadObject()) {
169         LOG(ERROR) << "makeVersionedIPreparedModel failed to register a death recipient for the "
170                       "IPreparedModel object because the IPreparedModel object is dead.";
171         return {ANEURALNETWORKS_DEAD_OBJECT, nullptr};
172     }
173     if (!ret.isOk()) {
174         LOG(ERROR) << "makeVersionedIPreparedModel failed to register a death recipient for the "
175                       "IPreparedModel object because of failure: "
176                    << ret.description();
177         return {ANEURALNETWORKS_OP_FAILED, nullptr};
178     }
179     if (ret != true) {
180         LOG(ERROR) << "makeVersionedIPreparedModel failed to register a death recipient for the "
181                       "IPreparedModel object.";
182         return {ANEURALNETWORKS_OP_FAILED, nullptr};
183     }
184 
185     // return a valid VersionedIPreparedModel object
186     return {ANEURALNETWORKS_NO_ERROR, std::make_shared<VersionedIPreparedModel>(
187                                               std::move(preparedModel), std::move(deathHandler))};
188 }
189 
VersionedIPreparedModel(sp<V1_0::IPreparedModel> preparedModel,sp<IPreparedModelDeathHandler> deathHandler)190 VersionedIPreparedModel::VersionedIPreparedModel(sp<V1_0::IPreparedModel> preparedModel,
191                                                  sp<IPreparedModelDeathHandler> deathHandler)
192     : mPreparedModelV1_0(std::move(preparedModel)),
193       mPreparedModelV1_2(V1_2::IPreparedModel::castFrom(mPreparedModelV1_0).withDefault(nullptr)),
194       mPreparedModelV1_3(V1_3::IPreparedModel::castFrom(mPreparedModelV1_0).withDefault(nullptr)),
195       mDeathHandler(std::move(deathHandler)) {}
196 
~VersionedIPreparedModel()197 VersionedIPreparedModel::~VersionedIPreparedModel() {
198     // It is safe to ignore any errors resulting from this unlinkToDeath call
199     // because the VersionedIPreparedModel object is already being destroyed and
200     // its underlying IPreparedModel object is no longer being used by the NN
201     // runtime.
202     mPreparedModelV1_0->unlinkToDeath(mDeathHandler).isOk();
203 }
204 
executeAsynchronously(const Request & request,MeasureTiming measure,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration) const205 std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::executeAsynchronously(
206         const Request& request, MeasureTiming measure, const std::optional<Deadline>& deadline,
207         const OptionalTimeoutDuration& loopTimeoutDuration) const {
208     const auto failDeadObject = []() -> std::tuple<int, std::vector<OutputShape>, Timing> {
209         return {ANEURALNETWORKS_DEAD_OBJECT, {}, kNoTiming};
210     };
211     const auto failWithStatus = [](ErrorStatus status) {
212         return getExecutionResult(status, {}, kNoTiming);
213     };
214     const auto getResults = [failDeadObject](const ExecutionCallback& cb) {
215         if (cb.isDeadObject()) {
216             return failDeadObject();
217         }
218         return getExecutionResult(cb.getStatus(), cb.getOutputShapes(), cb.getTiming());
219     };
220 
221     const sp<ExecutionCallback> callback = new ExecutionCallback();
222     const auto scoped = mDeathHandler->protectCallback(callback);
223 
224     // version 1.3+ HAL
225     if (mPreparedModelV1_3 != nullptr) {
226         const auto otp = makeTimePoint(deadline);
227         Return<ErrorStatus> ret = mPreparedModelV1_3->execute_1_3(request, measure, otp,
228                                                                   loopTimeoutDuration, callback);
229         if (ret.isDeadObject()) {
230             LOG(ERROR) << "execute_1_3 failure: " << ret.description();
231             return failDeadObject();
232         }
233         if (!ret.isOk()) {
234             LOG(ERROR) << "execute_1_3 failure: " << ret.description();
235             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
236         }
237         if (ret != ErrorStatus::NONE) {
238             LOG(ERROR) << "execute_1_3 returned " << toString(static_cast<ErrorStatus>(ret));
239             return failWithStatus(ret);
240         }
241         callback->wait();
242         return getResults(*callback);
243     }
244 
245     // version 1.2 HAL
246     if (mPreparedModelV1_2 != nullptr) {
247         const bool compliant = compliantWithV1_2(request);
248         if (!compliant) {
249             LOG(ERROR) << "Could not handle execute_1_2!";
250             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
251         }
252         const V1_0::Request request12 = convertToV1_2(request);
253         Return<V1_0::ErrorStatus> ret =
254                 mPreparedModelV1_2->execute_1_2(request12, measure, callback);
255         if (ret.isDeadObject()) {
256             LOG(ERROR) << "execute_1_2 failure: " << ret.description();
257             return failDeadObject();
258         }
259         if (!ret.isOk()) {
260             LOG(ERROR) << "execute_1_2 failure: " << ret.description();
261             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
262         }
263         const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
264         if (status != V1_0::ErrorStatus::NONE) {
265             LOG(ERROR) << "execute_1_2 returned " << toString(status);
266             return failWithStatus(convertToV1_3(status));
267         }
268         callback->wait();
269         return getResults(*callback);
270     }
271 
272     // version 1.0 HAL
273     if (mPreparedModelV1_0 != nullptr) {
274         const bool compliant = compliantWithV1_0(request);
275         if (!compliant) {
276             LOG(ERROR) << "Could not handle execute!";
277             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
278         }
279         const V1_0::Request request10 = convertToV1_0(request);
280         Return<V1_0::ErrorStatus> ret = mPreparedModelV1_0->execute(request10, callback);
281         if (ret.isDeadObject()) {
282             LOG(ERROR) << "execute failure: " << ret.description();
283             return failDeadObject();
284         }
285         if (!ret.isOk()) {
286             LOG(ERROR) << "execute failure: " << ret.description();
287             return failWithStatus(ErrorStatus::GENERAL_FAILURE);
288         }
289         const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
290         if (status != V1_0::ErrorStatus::NONE) {
291             LOG(ERROR) << "execute returned " << toString(status);
292             return failWithStatus(convertToV1_3(status));
293         }
294         callback->wait();
295         return getResults(*callback);
296     }
297 
298     // No prepared model available
299     LOG(ERROR) << "executeAsynchronously called with no preparedModel";
300     return failWithStatus(ErrorStatus::GENERAL_FAILURE);
301 }
302 
executeSynchronously(const Request & request,MeasureTiming measure,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration) const303 std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::executeSynchronously(
304         const Request& request, MeasureTiming measure, const std::optional<Deadline>& deadline,
305         const OptionalTimeoutDuration& loopTimeoutDuration) const {
306     const std::tuple<int, std::vector<OutputShape>, Timing> kDeadObject = {
307             ANEURALNETWORKS_DEAD_OBJECT, {}, kNoTiming};
308     const auto kFailure = getExecutionResult(ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
309 
310     // version 1.3+ HAL
311     if (mPreparedModelV1_3 != nullptr) {
312         std::tuple<int, std::vector<OutputShape>, Timing> result;
313         const auto otp = makeTimePoint(deadline);
314         Return<void> ret = mPreparedModelV1_3->executeSynchronously_1_3(
315                 request, measure, otp, loopTimeoutDuration,
316                 [&result](ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
317                           const Timing& timing) {
318                     result = getExecutionResult(error, outputShapes, timing);
319                 });
320         if (ret.isDeadObject()) {
321             LOG(ERROR) << "executeSynchronously_1_3 failure: " << ret.description();
322             return kDeadObject;
323         }
324         if (!ret.isOk()) {
325             LOG(ERROR) << "executeSynchronously_1_3 failure: " << ret.description();
326             return kFailure;
327         }
328         return result;
329     }
330 
331     // version 1.2 HAL
332     if (mPreparedModelV1_2 != nullptr) {
333         const bool compliant = compliantWithV1_2(request);
334         if (!compliant) {
335             LOG(ERROR) << "Could not handle executeSynchronously!";
336             return kFailure;
337         }
338         const V1_0::Request request12 = convertToV1_2(request);
339 
340         std::tuple<int, std::vector<OutputShape>, Timing> result;
341         Return<void> ret = mPreparedModelV1_2->executeSynchronously(
342                 request12, measure,
343                 [&result](V1_0::ErrorStatus error, const hidl_vec<OutputShape>& outputShapes,
344                           const Timing& timing) {
345                     result = getExecutionResult(convertToV1_3(error), outputShapes, timing);
346                 });
347         if (ret.isDeadObject()) {
348             LOG(ERROR) << "executeSynchronously failure: " << ret.description();
349             return kDeadObject;
350         }
351         if (!ret.isOk()) {
352             LOG(ERROR) << "executeSynchronously failure: " << ret.description();
353             return kFailure;
354         }
355         return result;
356     }
357 
358     // Fallback to asynchronous execution.
359     return executeAsynchronously(request, measure, deadline, loopTimeoutDuration);
360 }
361 
execute(const Request & request,MeasureTiming measure,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,bool preferSynchronous) const362 std::tuple<int, std::vector<OutputShape>, Timing> VersionedIPreparedModel::execute(
363         const Request& request, MeasureTiming measure, const std::optional<Deadline>& deadline,
364         const OptionalTimeoutDuration& loopTimeoutDuration, bool preferSynchronous) const {
365     if (preferSynchronous) {
366         VLOG(EXECUTION) << "Before executeSynchronously() " << SHOW_IF_DEBUG(toString(request));
367         return executeSynchronously(request, measure, deadline, loopTimeoutDuration);
368     }
369 
370     VLOG(EXECUTION) << "Before executeAsynchronously() " << SHOW_IF_DEBUG(toString(request));
371     return executeAsynchronously(request, measure, deadline, loopTimeoutDuration);
372 }
373 
// Returns how long the ExecutionBurstController should poll the FMQ for available data
// before falling back to waiting on the futex.
//
// The default is 50 microseconds. In NN_DEBUGGABLE builds it can be overridden with the
// "debug.nn.burst-conrtoller-polling-window" system property (that misspelling is the
// property name actually read); a missing, non-integer, or negative value leaves the
// default in place.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t defaultPollingTimeWindow = 50;
#ifndef NN_DEBUGGABLE
    return std::chrono::microseconds{defaultPollingTimeWindow};
#else
    constexpr int32_t minPollingTimeWindow = 0;
    const int32_t overriddenWindow =
            base::GetIntProperty("debug.nn.burst-conrtoller-polling-window",
                                 defaultPollingTimeWindow, minPollingTimeWindow);
    return std::chrono::microseconds{overriddenWindow};
#endif  // NN_DEBUGGABLE
}
389 
configureExecutionBurst(bool preferPowerOverLatency) const390 std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
391         bool preferPowerOverLatency) const {
392     if (mPreparedModelV1_2 == nullptr) {
393         return nullptr;
394     }
395     const auto pollingTimeWindow =
396             (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
397     return ExecutionBurstController::create(mPreparedModelV1_2, pollingTimeWindow);
398 }
399 
getCapabilitiesFunction(V1_3::IDevice * device)400 static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_3::IDevice* device) {
401     CHECK(device != nullptr);
402     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_3");
403     const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
404     std::pair<ErrorStatus, Capabilities> result = kFailure;
405     const Return<void> ret = device->getCapabilities_1_3(
406             [&result](ErrorStatus error, const Capabilities& capabilities) {
407                 result = std::make_pair(error, capabilities);
408             });
409     if (!ret.isOk()) {
410         LOG(ERROR) << "getCapabilities_1_3 failure: " << ret.description();
411         return kFailure;
412     }
413     return result;
414 }
415 
416 std::tuple<int, hal::hidl_handle, sp<hal::IFencedExecutionCallback>, hal::Timing>
executeFenced(const hal::Request & request,const hal::hidl_vec<hal::hidl_handle> & waitFor,MeasureTiming measure,const std::optional<Deadline> & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,const hal::OptionalTimeoutDuration & timeoutDurationAfterFence)417 VersionedIPreparedModel::executeFenced(
418         const hal::Request& request, const hal::hidl_vec<hal::hidl_handle>& waitFor,
419         MeasureTiming measure, const std::optional<Deadline>& deadline,
420         const OptionalTimeoutDuration& loopTimeoutDuration,
421         const hal::OptionalTimeoutDuration& timeoutDurationAfterFence) {
422     // version 1.3+ HAL
423     hal::hidl_handle syncFence;
424     sp<hal::IFencedExecutionCallback> dispatchCallback;
425     hal::Timing timing = {UINT64_MAX, UINT64_MAX};
426     if (mPreparedModelV1_3 != nullptr) {
427         ErrorStatus errorStatus;
428         const auto otp = makeTimePoint(deadline);
429         Return<void> ret = mPreparedModelV1_3->executeFenced(
430                 request, waitFor, measure, otp, loopTimeoutDuration, timeoutDurationAfterFence,
431                 [&syncFence, &errorStatus, &dispatchCallback](
432                         ErrorStatus error, const hidl_handle& handle,
433                         const sp<hal::IFencedExecutionCallback>& callback) {
434                     syncFence = handle;
435                     errorStatus = error;
436                     dispatchCallback = callback;
437                 });
438         if (!ret.isOk()) {
439             LOG(ERROR) << "executeFenced failure: " << ret.description();
440             return std::make_tuple(ANEURALNETWORKS_OP_FAILED, hal::hidl_handle(nullptr), nullptr,
441                                    timing);
442         }
443         if (errorStatus != ErrorStatus::NONE) {
444             LOG(ERROR) << "executeFenced returned "
445                        << toString(static_cast<ErrorStatus>(errorStatus));
446             return std::make_tuple(convertErrorStatusToResultCode(errorStatus),
447                                    hal::hidl_handle(nullptr), nullptr, timing);
448         }
449         return std::make_tuple(ANEURALNETWORKS_NO_ERROR, syncFence, dispatchCallback, timing);
450     }
451 
452     // fallback to synchronous execution if sync_fence is not supported
453     // first wait for all sync fences to be ready.
454     LOG(INFO) << "No drivers able to handle sync fences, falling back to regular execution";
455     for (const auto& fenceHandle : waitFor) {
456         if (!fenceHandle.getNativeHandle()) {
457             return std::make_tuple(ANEURALNETWORKS_BAD_DATA, hal::hidl_handle(nullptr), nullptr,
458                                    timing);
459         }
460         int syncFd = fenceHandle.getNativeHandle()->data[0];
461         if (syncFd <= 0) {
462             return std::make_tuple(ANEURALNETWORKS_BAD_DATA, hal::hidl_handle(nullptr), nullptr,
463                                    timing);
464         }
465         auto r = syncWait(syncFd, -1);
466         if (r != FenceState::SIGNALED) {
467             LOG(ERROR) << "syncWait failed, fd: " << syncFd;
468             return std::make_tuple(ANEURALNETWORKS_OP_FAILED, hal::hidl_handle(nullptr), nullptr,
469                                    timing);
470         }
471     }
472     int errorCode;
473     std::tie(errorCode, std::ignore, timing) =
474             executeSynchronously(request, measure, deadline, loopTimeoutDuration);
475     return std::make_tuple(errorCode, hal::hidl_handle(nullptr), nullptr, timing);
476 }
477 
getCapabilitiesFunction(V1_2::IDevice * device)478 static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_2::IDevice* device) {
479     CHECK(device != nullptr);
480     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_2");
481     const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
482     std::pair<ErrorStatus, Capabilities> result = kFailure;
483     const Return<void> ret = device->getCapabilities_1_2(
484             [&result](V1_0::ErrorStatus error, const V1_2::Capabilities& capabilities) {
485                 result = std::make_pair(convertToV1_3(error), convertToV1_3(capabilities));
486             });
487     if (!ret.isOk()) {
488         LOG(ERROR) << "getCapabilities_1_2 failure: " << ret.description();
489         return kFailure;
490     }
491     return result;
492 }
493 
getCapabilitiesFunction(V1_1::IDevice * device)494 static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_1::IDevice* device) {
495     CHECK(device != nullptr);
496     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_1");
497     const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
498     std::pair<ErrorStatus, Capabilities> result = kFailure;
499     const Return<void> ret = device->getCapabilities_1_1(
500             [&result](V1_0::ErrorStatus error, const V1_1::Capabilities& capabilities) {
501                 // Time taken to convert capabilities is trivial
502                 result = std::make_pair(convertToV1_3(error), convertToV1_3(capabilities));
503             });
504     if (!ret.isOk()) {
505         LOG(ERROR) << "getCapabilities_1_1 failure: " << ret.description();
506         return kFailure;
507     }
508     return result;
509 }
510 
getCapabilitiesFunction(V1_0::IDevice * device)511 static std::pair<ErrorStatus, Capabilities> getCapabilitiesFunction(V1_0::IDevice* device) {
512     CHECK(device != nullptr);
513     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities");
514     const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
515     std::pair<ErrorStatus, Capabilities> result = kFailure;
516     const Return<void> ret = device->getCapabilities(
517             [&result](V1_0::ErrorStatus error, const V1_0::Capabilities& capabilities) {
518                 // Time taken to convert capabilities is trivial
519                 result = std::make_pair(convertToV1_3(error), convertToV1_3(capabilities));
520             });
521     if (!ret.isOk()) {
522         LOG(ERROR) << "getCapabilities failure: " << ret.description();
523         return kFailure;
524     }
525     return result;
526 }
527 
getSupportedExtensionsFunction(V1_2::IDevice * device)528 static std::pair<ErrorStatus, hidl_vec<Extension>> getSupportedExtensionsFunction(
529         V1_2::IDevice* device) {
530     CHECK(device != nullptr);
531     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getSupportedExtensions");
532     const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
533     std::pair<ErrorStatus, hidl_vec<Extension>> result = kFailure;
534     const Return<void> ret = device->getSupportedExtensions(
535             [&result](V1_0::ErrorStatus error, const hidl_vec<Extension>& extensions) {
536                 result = std::make_pair(convertToV1_3(error), extensions);
537             });
538     if (!ret.isOk()) {
539         LOG(ERROR) << "getSupportedExtensions failure: " << ret.description();
540         return kFailure;
541     }
542     return result;
543 }
544 
getSupportedExtensionsFunction(V1_0::IDevice * device)545 static std::pair<ErrorStatus, hidl_vec<Extension>> getSupportedExtensionsFunction(
546         V1_0::IDevice* device) {
547     CHECK(device != nullptr);
548     return {ErrorStatus::NONE, {/* No extensions. */}};
549 }
550 
getTypeFunction(V1_2::IDevice * device)551 static int32_t getTypeFunction(V1_2::IDevice* device) {
552     CHECK(device != nullptr);
553     constexpr int32_t kFailure = -1;
554     int32_t result = kFailure;
555     const Return<void> ret =
556             device->getType([&result](V1_0::ErrorStatus error, DeviceType deviceType) {
557                 if (error == V1_0::ErrorStatus::NONE) {
558                     result = static_cast<int32_t>(deviceType);
559                 }
560             });
561     if (!ret.isOk()) {
562         LOG(ERROR) << "getType failure: " << ret.description();
563         return kFailure;
564     }
565     return result;
566 }
567 
getTypeFunction(V1_0::IDevice * device)568 static int32_t getTypeFunction(V1_0::IDevice* device) {
569     CHECK(device != nullptr);
570     return ANEURALNETWORKS_DEVICE_UNKNOWN;
571 }
572 
getVersionStringFunction(V1_2::IDevice * device)573 static std::pair<ErrorStatus, hidl_string> getVersionStringFunction(V1_2::IDevice* device) {
574     CHECK(device != nullptr);
575     const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""};
576     std::pair<ErrorStatus, hidl_string> result = kFailure;
577     const Return<void> ret = device->getVersionString(
578             [&result](V1_0::ErrorStatus error, const hidl_string& version) {
579                 result = std::make_pair(convertToV1_3(error), version);
580             });
581     if (!ret.isOk()) {
582         LOG(ERROR) << "getVersion failure: " << ret.description();
583         return kFailure;
584     }
585     return result;
586 }
587 
getVersionStringFunction(V1_0::IDevice * device)588 static std::pair<ErrorStatus, hidl_string> getVersionStringFunction(V1_0::IDevice* device) {
589     CHECK(device != nullptr);
590     return {ErrorStatus::NONE, "UNKNOWN"};
591 }
592 
getNumberOfCacheFilesNeededFunction(V1_2::IDevice * device)593 static std::tuple<ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeededFunction(
594         V1_2::IDevice* device) {
595     CHECK(device != nullptr);
596     constexpr std::tuple<ErrorStatus, uint32_t, uint32_t> kFailure = {ErrorStatus::GENERAL_FAILURE,
597                                                                       0, 0};
598     std::tuple<ErrorStatus, uint32_t, uint32_t> result = kFailure;
599     const Return<void> ret = device->getNumberOfCacheFilesNeeded(
600             [&result](V1_0::ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
601                 result = {convertToV1_3(error), numModelCache, numDataCache};
602             });
603     if (!ret.isOk()) {
604         LOG(ERROR) << "getNumberOfCacheFilesNeeded failure: " << ret.description();
605         return kFailure;
606     }
607     return result;
608 }
609 
getNumberOfCacheFilesNeededFunction(V1_0::IDevice * device)610 static std::tuple<ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeededFunction(
611         V1_0::IDevice* device) {
612     CHECK(device != nullptr);
613     return {ErrorStatus::NONE, 0, 0};
614 }
615 
// Information queried from a driver during initialization. initializeFunction() builds
// it positionally (so member order matters) from the results of the get*Function
// helpers in this file.
struct InitialData {
    hal::Capabilities capabilities;
    hal::hidl_vec<hal::Extension> supportedExtensions;
    // Result of getTypeFunction(); ANEURALNETWORKS_DEVICE_UNKNOWN for devices handled
    // by its V1_0 overload.
    int32_t type;
    hal::hidl_string versionString;
    // presumably {numModelCacheFiles, numDataCacheFiles} — TODO confirm against
    // initializeFunction().
    std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded;
};
623 
624 template <typename Device>
initializeFunction(Device * device)625 static std::optional<InitialData> initializeFunction(Device* device) {
626     CHECK(device != nullptr);
627 
628     auto [capabilitiesStatus, capabilities] = getCapabilitiesFunction(device);
629     if (capabilitiesStatus != ErrorStatus::NONE) {
630         LOG(ERROR) << "IDevice::getCapabilities* returned the error "
631                    << toString(capabilitiesStatus);
632         return std::nullopt;
633     }
634     VLOG(MANAGER) << "Capab " << toString(capabilities);
635 
636     auto [versionStatus, versionString] = getVersionStringFunction(device);
637     if (versionStatus != ErrorStatus::NONE) {
638         LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(versionStatus);
639         return std::nullopt;
640     }
641 
642     const int32_t type = getTypeFunction(device);
643     if (type == -1) {
644         LOG(ERROR) << "IDevice::getType returned an error";
645         return std::nullopt;
646     }
647 
648     auto [extensionsStatus, supportedExtensions] = getSupportedExtensionsFunction(device);
649     if (extensionsStatus != ErrorStatus::NONE) {
650         LOG(ERROR) << "IDevice::getSupportedExtensions returned the error "
651                    << toString(extensionsStatus);
652         return std::nullopt;
653     }
654 
655     const auto [cacheFilesStatus, numModelCacheFiles, numDataCacheFiles] =
656             getNumberOfCacheFilesNeededFunction(device);
657     if (cacheFilesStatus != ErrorStatus::NONE) {
658         LOG(ERROR) << "IDevice::getNumberOfCacheFilesNeeded returned the error "
659                    << toString(cacheFilesStatus);
660         return std::nullopt;
661     }
662 
663     // The following limit is enforced by VTS
664     constexpr uint32_t maxNumCacheFiles =
665             static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES);
666     if (numModelCacheFiles > maxNumCacheFiles || numDataCacheFiles > maxNumCacheFiles) {
667         LOG(ERROR)
668                 << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files: "
669                    "numModelCacheFiles = "
670                 << numModelCacheFiles << ", numDataCacheFiles = " << numDataCacheFiles
671                 << ", maxNumCacheFiles = " << maxNumCacheFiles;
672         return std::nullopt;
673     }
674 
675     return InitialData{
676             /*.capabilities=*/std::move(capabilities),
677             /*.supportedExtensions=*/std::move(supportedExtensions),
678             /*.type=*/type,
679             /*.versionString=*/std::move(versionString),
680             /*.numberOfCacheFilesNeeded=*/{numModelCacheFiles, numDataCacheFiles},
681     };
682 }
683 
684 template <typename Core>
initialize(const Core & core)685 std::optional<InitialData> initialize(const Core& core) {
686     // version 1.3+ HAL
687     if (const auto device = core.template getDevice<V1_3::IDevice>()) {
688         return initializeFunction(device.get());
689     }
690 
691     // version 1.2 HAL
692     if (const auto device = core.template getDevice<V1_2::IDevice>()) {
693         return initializeFunction(device.get());
694     }
695 
696     // version 1.1 HAL
697     if (const auto device = core.template getDevice<V1_1::IDevice>()) {
698         return initializeFunction(device.get());
699     }
700 
701     // version 1.0 HAL
702     if (const auto device = core.template getDevice<V1_0::IDevice>()) {
703         return initializeFunction(device.get());
704     }
705 
706     // No device available
707     LOG(ERROR) << "Device not available!";
708     return std::nullopt;
709 }
710 
create(std::string serviceName,const DeviceFactory & makeDevice)711 std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
712                                                            const DeviceFactory& makeDevice) {
713     CHECK(makeDevice != nullptr)
714             << "VersionedIDevice::create passed invalid device factory object.";
715 
716     // get handle to IDevice object
717     sp<V1_0::IDevice> device = makeDevice(/*blocking=*/true);
718     if (device == nullptr) {
719         VLOG(DRIVER) << "VersionedIDevice::create got a null IDevice for " << serviceName;
720         return nullptr;
721     }
722 
723     auto core = Core::create(std::move(device));
724     if (!core.has_value()) {
725         LOG(ERROR) << "VersionedIDevice::create failed to create Core.";
726         return nullptr;
727     }
728 
729     auto initialData = initialize(*core);
730     if (!initialData.has_value()) {
731         LOG(ERROR) << "VersionedIDevice::create failed to initialize.";
732         return nullptr;
733     }
734 
735     auto [capabilities, supportedExtensions, type, versionString, numberOfCacheFilesNeeded] =
736             std::move(*initialData);
737     return std::make_shared<VersionedIDevice>(
738             std::move(capabilities), std::move(supportedExtensions), type, std::move(versionString),
739             numberOfCacheFilesNeeded, std::move(serviceName), makeDevice, std::move(core.value()));
740 }
741 
VersionedIDevice(hal::Capabilities capabilities,std::vector<hal::Extension> supportedExtensions,int32_t type,std::string versionString,std::pair<uint32_t,uint32_t> numberOfCacheFilesNeeded,std::string serviceName,const DeviceFactory & makeDevice,Core core)742 VersionedIDevice::VersionedIDevice(hal::Capabilities capabilities,
743                                    std::vector<hal::Extension> supportedExtensions, int32_t type,
744                                    std::string versionString,
745                                    std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
746                                    std::string serviceName, const DeviceFactory& makeDevice,
747                                    Core core)
748     : kCapabilities(std::move(capabilities)),
749       kSupportedExtensions(std::move(supportedExtensions)),
750       kType(type),
751       kVersionString(std::move(versionString)),
752       kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded),
753       kServiceName(std::move(serviceName)),
754       kMakeDevice(makeDevice),
755       mCore(std::move(core)) {}
756 
create(sp<V1_0::IDevice> device)757 std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
758     CHECK(device != nullptr) << "VersionedIDevice::Core::create passed invalid device object.";
759 
760     // create death handler object
761     sp<IDeviceDeathHandler> deathHandler = new IDeviceDeathHandler();
762 
763     // linkToDeath registers a callback that will be invoked on service death to
764     // proactively handle service crashes. If the linkToDeath call fails,
765     // asynchronous calls are susceptible to hangs if the service crashes before
766     // providing the response.
767     const Return<bool> ret = device->linkToDeath(deathHandler, 0);
768     if (!ret.isOk()) {
769         LOG(ERROR) << "VersionedIDevice::Core::create failed to register a death recipient for the "
770                       "IDevice object because of failure: "
771                    << ret.description();
772         return {};
773     }
774     if (ret != true) {
775         LOG(ERROR) << "VersionedIDevice::Core::create failed to register a death recipient for the "
776                       "IDevice object.";
777         return {};
778     }
779 
780     // return a valid Core object
781     return Core(std::move(device), std::move(deathHandler));
782 }
783 
// Stores `device` and caches it cast to each newer HAL interface. The casts from the V1_0 handle
// are possible because HIDL guarantees that every V1_N interface inherits from the corresponding
// older interface. Each cast falls back to nullptr (via withDefault) when it does not succeed —
// presumably when the service does not implement that version — so a non-null mDeviceV1_N is
// what the rest of this file uses to detect HAL 1.N support.
// `deathHandler` is expected to be already linked to `device` (as done by Core::create); it is
// unlinked again in ~Core().
VersionedIDevice::Core::Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler)
    : mDeviceV1_0(std::move(device)),
      mDeviceV1_1(V1_1::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeviceV1_2(V1_2::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeviceV1_3(V1_3::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
      mDeathHandler(std::move(deathHandler)) {}
791 
~Core()792 VersionedIDevice::Core::~Core() {
793     if (mDeathHandler != nullptr) {
794         CHECK(mDeviceV1_0 != nullptr);
795         // It is safe to ignore any errors resulting from this unlinkToDeath call
796         // because the VersionedIDevice::Core object is already being destroyed and
797         // its underlying IDevice object is no longer being used by the NN runtime.
798         mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk();
799     }
800 }
801 
Core(Core && other)802 VersionedIDevice::Core::Core(Core&& other) noexcept
803     : mDeviceV1_0(std::move(other.mDeviceV1_0)),
804       mDeviceV1_1(std::move(other.mDeviceV1_1)),
805       mDeviceV1_2(std::move(other.mDeviceV1_2)),
806       mDeviceV1_3(std::move(other.mDeviceV1_3)),
807       mDeathHandler(std::move(other.mDeathHandler)) {
808     other.mDeathHandler = nullptr;
809 }
810 
operator =(Core && other)811 VersionedIDevice::Core& VersionedIDevice::Core::operator=(Core&& other) noexcept {
812     if (this != &other) {
813         mDeviceV1_0 = std::move(other.mDeviceV1_0);
814         mDeviceV1_1 = std::move(other.mDeviceV1_1);
815         mDeviceV1_2 = std::move(other.mDeviceV1_2);
816         mDeviceV1_3 = std::move(other.mDeviceV1_3);
817         mDeathHandler = std::move(other.mDeathHandler);
818         other.mDeathHandler = nullptr;
819     }
820     return *this;
821 }
822 
823 template <typename T_IDevice>
getDeviceAndDeathHandler() const824 std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> VersionedIDevice::Core::getDeviceAndDeathHandler()
825         const {
826     return {getDevice<T_IDevice>(), mDeathHandler};
827 }
828 
// Invokes `fn(device)` while `callback` is registered with `deathHandler`, then blocks until
// `callback` has been signalled.
//
// `context` is used only in log messages. If the call fails (transport error) or returns a
// status other than T_Return::NONE, the callback is explicitly sent a failure message. This
// keeps callback->wait() below from waiting on a callback that the driver may never signal.
//
// Returns the Return<> produced by `fn` unchanged.
template <typename T_Return, typename T_IDevice, typename T_Callback>
Return<T_Return> callProtected(const char* context,
                               const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
                               const sp<T_IDevice>& device, const sp<T_Callback>& callback,
                               const sp<IDeviceDeathHandler>& deathHandler) {
    // `scoped` keeps the callback registered with the death handler until this function returns,
    // i.e. until after callback->wait() below.
    const auto scoped = deathHandler->protectCallback(callback);
    Return<T_Return> ret = fn(device);
    // Suppose there was a transport error.  We have the following cases:
    // 1. Either not due to a dead device, or due to a device that was
    //    already dead at the time of the call to protectCallback().  In
    //    this case, the callback was never signalled.
    // 2. Due to a device that died after the call to protectCallback() but
    //    before fn() completed.  In this case, the callback was (or will
    //    be) signalled by the deathHandler.
    // Furthermore, what if there was no transport error, but the ErrorStatus is
    // other than NONE?  We'll conservatively signal the callback anyway, just in
    // case the driver was sloppy and failed to do so.
    if (!ret.isOk() || ret != T_Return::NONE) {
        // What if the deathHandler has signalled or will signal the callback?
        // This is fine -- we're permitted to signal multiple times; and we're
        // sending the same signal that the deathHandler does.
        //
        // What if the driver signalled the callback?  Then this signal is
        // ignored.

        if (ret.isOk()) {
            LOG(ERROR) << context << " returned " << toString(static_cast<T_Return>(ret));
        } else {
            LOG(ERROR) << context << " failure: " << ret.description();
        }
        sendFailureMessage(callback.get());
    }
    callback->wait();
    return ret;
}
// Overload selected when there is no asynchronous callback (recoverable() was given nullptr):
// there is nothing to register with the death handler or to wait on, so this simply forwards
// to `fn(device)` and returns its result. `context` and the death handler are unused.
template <typename T_Return, typename T_IDevice>
Return<T_Return> callProtected(const char* /*context*/,
                               const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
                               const sp<T_IDevice>& device, const std::nullptr_t& /*callback*/,
                               const sp<IDeviceDeathHandler>& /*deathHandler*/) {
    return fn(device);
}
871 
// Calls `fn` on the device through HAL interface T_IDevice (via callProtected). If that call
// fails with a DEAD_OBJECT transport error, this tries once to restore the connection and
// then retries `fn` exactly once.
//
// `callback` must be null exactly when T_Callback is std::nullptr_t (enforced by CHECK_EQ).
// Recovery runs under mMutex. It first pings the current device, since another thread may
// already have replaced it. If the ping also reports a dead object, a new device is requested
// from kMakeDevice with blocking=false and installed in mCore. If the new device cannot be
// obtained or wrapped, the original dead-object Return is returned without a retry.
// The retry itself runs outside the lock.
//
// NOTE(review): assumes the device held by mCore always implements T_IDevice; if
// getDevice<T_IDevice>() were null, the ping() below would dereference null.
template <typename T_Return, typename T_IDevice, typename T_Callback>
Return<T_Return> VersionedIDevice::recoverable(
        const char* context, const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
        const T_Callback& callback) const EXCLUDES(mMutex) {
    CHECK_EQ(callback == nullptr, (std::is_same_v<T_Callback, std::nullptr_t>));

    sp<T_IDevice> device;
    sp<IDeviceDeathHandler> deathHandler;
    std::tie(device, deathHandler) = getDeviceAndDeathHandler<T_IDevice>();

    Return<T_Return> ret = callProtected(context, fn, device, callback, deathHandler);

    if (ret.isDeadObject()) {
        {
            std::unique_lock lock(mMutex);
            // It's possible that another thread has already done the recovery.
            // It's harmless but wasteful for us to do so in this case.
            auto pingReturn = mCore.getDevice<T_IDevice>()->ping();
            if (pingReturn.isDeadObject()) {
                VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering "
                             << kServiceName;
                sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/false);
                if (recoveredDevice == nullptr) {
                    VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for "
                                 << kServiceName;
                    return ret;
                }

                auto core = Core::create(std::move(recoveredDevice));
                if (!core.has_value()) {
                    LOG(ERROR) << "VersionedIDevice::recoverable failed to create Core.";
                    return ret;
                }

                mCore = std::move(core.value());
            } else {
                VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context
                             << ") -- Someone else recovered " << kServiceName;
                // Might still have a transport error, which we need to check
                // before pingReturn goes out of scope.
                (void)pingReturn.isOk();
            }
            // Take fresh references to the (possibly replaced) device while still holding the
            // lock; the retry below uses these copies without the lock.
            std::tie(device, deathHandler) = mCore.getDeviceAndDeathHandler<T_IDevice>();
        }
        ret = callProtected(context, fn, device, callback, deathHandler);
        // It's possible that the device died again, but we're only going to
        // attempt recovery once per call to recoverable().
    }
    return ret;
}
922 
wait() const923 int VersionedIDevice::wait() const {
924     std::unique_lock lock(mMutex);
925     // It's possible that another device has already done the recovery.
926     // It's harmless but wasteful for us to do so in this case.
927     auto pingReturn = mCore.getDevice<V1_0::IDevice>()->ping();
928     if (pingReturn.isDeadObject()) {
929         VLOG(DRIVER) << "VersionedIDevice::wait -- Recovering " << kServiceName;
930         sp<V1_0::IDevice> recoveredDevice = kMakeDevice(/*blocking=*/true);
931         if (recoveredDevice == nullptr) {
932             LOG(ERROR) << "VersionedIDevice::wait got a null IDevice for " << kServiceName;
933             return ANEURALNETWORKS_OP_FAILED;
934         }
935 
936         auto core = Core::create(std::move(recoveredDevice));
937         if (!core.has_value()) {
938             LOG(ERROR) << "VersionedIDevice::wait failed to create Core.";
939             return ANEURALNETWORKS_OP_FAILED;
940         }
941 
942         mCore = std::move(core.value());
943     } else if (!pingReturn.isOk()) {
944         LOG(ERROR) << "VersionedIDevice::wait failed -- IDevice::ping returned "
945                    << pingReturn.description();
946         return ANEURALNETWORKS_OP_FAILED;
947     }
948 
949     return ANEURALNETWORKS_NO_ERROR;
950 }
951 
getCapabilities() const952 const Capabilities& VersionedIDevice::getCapabilities() const {
953     return kCapabilities;
954 }
955 
getSupportedExtensions() const956 const std::vector<Extension>& VersionedIDevice::getSupportedExtensions() const {
957     return kSupportedExtensions;
958 }
959 
getSupportedOperations(const MetaModel & metaModel) const960 std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
961         const MetaModel& metaModel) const {
962     const std::pair<ErrorStatus, hidl_vec<bool>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
963     std::pair<ErrorStatus, hidl_vec<bool>> result;
964 
965     const Model& model = metaModel.getModel();
966 
967     auto noneSupported = [&model] {
968         hidl_vec<bool> supported(model.main.operations.size());
969         std::fill(supported.begin(), supported.end(), false);
970         return std::make_pair(ErrorStatus::NONE, std::move(supported));
971     };
972 
973     auto remappedResult = [&model](const std::pair<ErrorStatus, hidl_vec<bool>>& result,
974                                    const std::function<uint32_t(uint32_t)>&
975                                            slicedModelOperationIndexToModelOperationIndex) {
976         const ErrorStatus status = result.first;
977         const hidl_vec<bool>& supported = result.second;
978         hidl_vec<bool> remappedSupported(model.main.operations.size());
979         std::fill(remappedSupported.begin(), remappedSupported.end(), false);
980         for (size_t i = 0; i < supported.size(); ++i) {
981             if (supported[i]) {
982                 remappedSupported[slicedModelOperationIndexToModelOperationIndex(i)] = true;
983             }
984         }
985         return std::make_pair(status, std::move(remappedSupported));
986     };
987 
988     // version 1.3+ HAL
989     if (getDevice<V1_3::IDevice>() != nullptr) {
990         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_3");
991         Return<void> ret = recoverable<void, V1_3::IDevice>(
992                 __FUNCTION__, [&model, &result](const sp<V1_3::IDevice>& device) {
993                     return device->getSupportedOperations_1_3(
994                             model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
995                                 result = std::make_pair(error, supported);
996                             });
997                 });
998         if (!ret.isOk()) {
999             LOG(ERROR) << "getSupportedOperations_1_3 failure: " << ret.description();
1000             return kFailure;
1001         }
1002         return result;
1003     }
1004 
1005     // version 1.2 HAL
1006     if (getDevice<V1_2::IDevice>() != nullptr) {
1007         const bool compliant = compliantWithV1_2(model);
1008         V1_2::Model model12;
1009         std::function<uint32_t(uint32_t)> slicedModelOperationIndexToModelOperationIndex;
1010         if (compliant) {
1011             model12 = convertToV1_2(model);
1012         } else {
1013             const auto slice12 = metaModel.getSliceV1_2();
1014             if (!slice12.has_value()) {
1015                 return noneSupported();
1016             }
1017             std::tie(model12, slicedModelOperationIndexToModelOperationIndex) = *slice12;
1018         }
1019         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_2");
1020         Return<void> ret = recoverable<void, V1_2::IDevice>(
1021                 __FUNCTION__, [&model12, &result](const sp<V1_2::IDevice>& device) {
1022                     return device->getSupportedOperations_1_2(
1023                             model12,
1024                             [&result](V1_0::ErrorStatus error, const hidl_vec<bool>& supported) {
1025                                 result = std::make_pair(convertToV1_3(error), supported);
1026                             });
1027                 });
1028         if (!ret.isOk()) {
1029             LOG(ERROR) << "getSupportedOperations_1_2 failure: " << ret.description();
1030             return kFailure;
1031         }
1032         if (!compliant) {
1033             return remappedResult(result, slicedModelOperationIndexToModelOperationIndex);
1034         }
1035         return result;
1036     }
1037 
1038     // version 1.1 HAL
1039     if (getDevice<V1_1::IDevice>() != nullptr) {
1040         const bool compliant = compliantWithV1_1(model);
1041         V1_1::Model model11;
1042         std::function<uint32_t(uint32_t)> slicedModelOperationIndexToModelOperationIndex;
1043         if (compliant) {
1044             model11 = convertToV1_1(model);
1045         } else {
1046             const auto slice11 = metaModel.getSliceV1_1();
1047             if (!slice11.has_value()) {
1048                 return noneSupported();
1049             }
1050             std::tie(model11, slicedModelOperationIndexToModelOperationIndex) = *slice11;
1051         }
1052         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_1");
1053         Return<void> ret = recoverable<void, V1_1::IDevice>(
1054                 __FUNCTION__, [&model11, &result](const sp<V1_1::IDevice>& device) {
1055                     return device->getSupportedOperations_1_1(
1056                             model11,
1057                             [&result](V1_0::ErrorStatus error, const hidl_vec<bool>& supported) {
1058                                 result = std::make_pair(convertToV1_3(error), supported);
1059                             });
1060                 });
1061         if (!ret.isOk()) {
1062             LOG(ERROR) << "getSupportedOperations_1_1 failure: " << ret.description();
1063             return kFailure;
1064         }
1065         if (!compliant) {
1066             return remappedResult(result, slicedModelOperationIndexToModelOperationIndex);
1067         }
1068         return result;
1069     }
1070 
1071     // version 1.0 HAL
1072     if (getDevice<V1_0::IDevice>() != nullptr) {
1073         const bool compliant = compliantWithV1_0(model);
1074         V1_0::Model model10;
1075         std::function<uint32_t(uint32_t)> slicedModelOperationIndexToModelOperationIndex;
1076         if (compliant) {
1077             model10 = convertToV1_0(model);
1078         } else {
1079             const auto slice10 = metaModel.getSliceV1_0();
1080             if (!slice10.has_value()) {
1081                 return noneSupported();
1082             }
1083             std::tie(model10, slicedModelOperationIndexToModelOperationIndex) = *slice10;
1084         }
1085         NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations");
1086         Return<void> ret = recoverable<void, V1_0::IDevice>(
1087                 __FUNCTION__, [&model10, &result](const sp<V1_0::IDevice>& device) {
1088                     return device->getSupportedOperations(
1089                             model10,
1090                             [&result](V1_0::ErrorStatus error, const hidl_vec<bool>& supported) {
1091                                 result = std::make_pair(convertToV1_3(error), supported);
1092                             });
1093                 });
1094         if (!ret.isOk()) {
1095             LOG(ERROR) << "getSupportedOperations failure: " << ret.description();
1096             return kFailure;
1097         }
1098         if (!compliant) {
1099             return remappedResult(result, slicedModelOperationIndexToModelOperationIndex);
1100         }
1101         return result;
1102     }
1103 
1104     // No device available
1105     LOG(ERROR) << "Device not available!";
1106     return kFailure;
1107 }
1108 
1109 // Opens cache file by filename and sets the handle to the opened fd. Returns false on fail. The
1110 // handle is expected to come in as empty, and is only set to a fd when the function returns true.
1111 // The file descriptor is always opened with both read and write permission.
createCacheHandle(const std::string & cache,bool createIfNotExist,hidl_handle * handle)1112 static bool createCacheHandle(const std::string& cache, bool createIfNotExist,
1113                               hidl_handle* handle) {
1114     CHECK(handle->getNativeHandle() == nullptr);
1115     int fd = open(cache.c_str(), createIfNotExist ? (O_RDWR | O_CREAT) : O_RDWR, S_IRUSR | S_IWUSR);
1116     NN_RET_CHECK_GE(fd, 0);
1117     native_handle_t* cacheNativeHandle = native_handle_create(1, 0);
1118     if (cacheNativeHandle == nullptr) {
1119         close(fd);
1120         return false;
1121     }
1122     cacheNativeHandle->data[0] = fd;
1123     handle->setTo(cacheNativeHandle, /*shouldOwn=*/true);
1124     return true;
1125 }
1126 
1127 // Opens a list of cache files and returns the handle vector. Returns empty vector on fail.
1128 // The file descriptors are always opened with both read and write permission.
createCacheHandleVec(uint32_t numCacheFiles,const std::string & baseFileName,bool createIfNotExist)1129 static hidl_vec<hidl_handle> createCacheHandleVec(uint32_t numCacheFiles,
1130                                                   const std::string& baseFileName,
1131                                                   bool createIfNotExist) {
1132     CHECK(numCacheFiles <= static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES));
1133     hidl_vec<hidl_handle> handles(numCacheFiles);
1134     for (uint32_t i = 0; i < numCacheFiles; i++) {
1135         std::string filename = baseFileName + std::to_string(i);
1136         VLOG(COMPILATION) << "Cache " << i << ": " << filename;
1137         if (!createCacheHandle(filename, createIfNotExist, &handles[i])) {
1138             return hidl_vec<hidl_handle>();
1139         }
1140     }
1141     return handles;
1142 }
1143 
1144 // Maps token to cache file names and sets the handle vectors to the opened fds. Returns false on
1145 // fail and leaves the vectors empty. Each vector is expected to come in as empty.
getCacheHandles(const std::string & cacheDir,const CacheToken & token,const std::pair<uint32_t,uint32_t> & numCacheFiles,bool createIfNotExist,hidl_vec<hidl_handle> * modelCache,hidl_vec<hidl_handle> * dataCache)1146 static bool getCacheHandles(const std::string& cacheDir, const CacheToken& token,
1147                             const std::pair<uint32_t, uint32_t>& numCacheFiles,
1148                             bool createIfNotExist, hidl_vec<hidl_handle>* modelCache,
1149                             hidl_vec<hidl_handle>* dataCache) {
1150     // The filename includes ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2 characters for token,
1151     // and 1 character for model/data cache identifier.
1152     std::string filename(ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2 + 1, '0');
1153     for (uint32_t i = 0; i < ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN; i++) {
1154         filename[i * 2] = 'A' + (token[i] & 0x0F);
1155         filename[i * 2 + 1] = 'A' + (token[i] >> 4);
1156     }
1157     CHECK(cacheDir.empty() || cacheDir.back() == '/');
1158     std::string cacheFileName = cacheDir + filename;
1159 
1160     const uint32_t cacheTypeIdentifierIndex =
1161             cacheDir.size() + ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN * 2;
1162     cacheFileName[cacheTypeIdentifierIndex] = '1';
1163     *modelCache = createCacheHandleVec(numCacheFiles.first, cacheFileName, createIfNotExist);
1164     if (modelCache->size() != numCacheFiles.first) {
1165         return false;
1166     }
1167     cacheFileName[cacheTypeIdentifierIndex] = '2';
1168     *dataCache = createCacheHandleVec(numCacheFiles.second, cacheFileName, createIfNotExist);
1169     if (dataCache->size() != numCacheFiles.second) {
1170         modelCache->resize(0);
1171         return false;
1172     }
1173     return true;
1174 }
1175 
prepareModelFailure(ErrorStatus status=ErrorStatus::GENERAL_FAILURE)1176 static std::pair<int, std::shared_ptr<VersionedIPreparedModel>> prepareModelFailure(
1177         ErrorStatus status = ErrorStatus::GENERAL_FAILURE) {
1178     return {convertErrorStatusToResultCode(status), nullptr};
1179 }
1180 
prepareModelResult(const PreparedModelCallback & callback,const char * prepareName,const std::string & serviceName)1181 static std::pair<int, std::shared_ptr<VersionedIPreparedModel>> prepareModelResult(
1182         const PreparedModelCallback& callback, const char* prepareName,
1183         const std::string& serviceName) {
1184     callback.wait();
1185     if (callback.isDeadObject()) {
1186         LOG(ERROR) << prepareName << " on " << serviceName
1187                    << " failed because the PreparedModel object is dead";
1188         return {ANEURALNETWORKS_DEAD_OBJECT, nullptr};
1189     }
1190     const ErrorStatus status = callback.getStatus();
1191     const sp<V1_0::IPreparedModel> preparedModel = callback.getPreparedModel();
1192 
1193     if (status != ErrorStatus::NONE) {
1194         LOG(ERROR) << prepareName << " on " << serviceName << " failed: "
1195                    << "prepareReturnStatus=" << toString(status);
1196         return prepareModelFailure(status);
1197     }
1198     if (preparedModel == nullptr) {
1199         LOG(ERROR) << prepareName << " on " << serviceName << " failed: preparedModel is nullptr";
1200         return prepareModelFailure();
1201     }
1202 
1203     return makeVersionedIPreparedModel(preparedModel);
1204 }
1205 
prepareModelInternal(const Model & model,ExecutionPreference preference,Priority priority,const std::optional<Deadline> & deadline,const std::string & cacheDir,const std::optional<CacheToken> & maybeToken) const1206 std::pair<int, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevice::prepareModelInternal(
1207         const Model& model, ExecutionPreference preference, Priority priority,
1208         const std::optional<Deadline>& deadline, const std::string& cacheDir,
1209         const std::optional<CacheToken>& maybeToken) const {
1210     // Note that some work within VersionedIDevice will be subtracted from the IPC layer
1211     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "prepareModel");
1212     const std::pair<int, std::shared_ptr<VersionedIPreparedModel>> kDeadObject = {
1213             ANEURALNETWORKS_DEAD_OBJECT, nullptr};
1214 
1215     // Get cache files if they exist, otherwise create them.
1216     hidl_vec<hidl_handle> modelCache, dataCache;
1217     if (!maybeToken.has_value() ||
1218         !getCacheHandles(cacheDir, *maybeToken, kNumberOfCacheFilesNeeded,
1219                          /*createIfNotExist=*/true, &modelCache, &dataCache)) {
1220         modelCache.resize(0);
1221         dataCache.resize(0);
1222     }
1223 
1224     // Get the token if it exists, otherwise get a null token.
1225     static const CacheToken kNullToken{};
1226     const CacheToken token = maybeToken.value_or(kNullToken);
1227 
1228     const sp<PreparedModelCallback> callback = new PreparedModelCallback();
1229 
1230     // If 1.3 device, try preparing model
1231     if (getDevice<V1_3::IDevice>() != nullptr) {
1232         const auto otp = makeTimePoint(deadline);
1233         const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_3::IDevice>(
1234                 __FUNCTION__,
1235                 [&model, preference, priority, &otp, &modelCache, &dataCache, &token,
1236                  &callback](const sp<V1_3::IDevice>& device) {
1237                     return device->prepareModel_1_3(model, preference, priority, otp, modelCache,
1238                                                     dataCache, token, callback);
1239                 },
1240                 callback);
1241         if (ret.isDeadObject()) {
1242             LOG(ERROR) << "prepareModel_1_3 failure: " << ret.description();
1243             return kDeadObject;
1244         }
1245         if (!ret.isOk()) {
1246             LOG(ERROR) << "prepareModel_1_3 failure: " << ret.description();
1247             return prepareModelFailure();
1248         }
1249         if (ret != ErrorStatus::NONE) {
1250             LOG(ERROR) << "prepareModel_1_3 returned " << toString(static_cast<ErrorStatus>(ret));
1251             return prepareModelFailure(ret);
1252         }
1253         return prepareModelResult(*callback, "prepareModel_1_3", kServiceName);
1254     }
1255 
1256     // If 1.2 device, try preparing model (requires conversion)
1257     if (getDevice<V1_2::IDevice>() != nullptr) {
1258         bool compliant = false;
1259         V1_2::Model model12;
1260         {
1261             // Attribute time spent in model inspection and conversion to
1262             // Runtime, as the time may be substantial (0.03ms for mobilenet,
1263             // but could be larger for other models).
1264             NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
1265                                   "VersionedIDevice::prepareModel_1_2");
1266             compliant = compliantWithV1_2(model);
1267             if (compliant) {
1268                 model12 = convertToV1_2(model);  // copy is elided
1269             }
1270         }
1271         if (compliant) {
1272             const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_2::IDevice>(
1273                     __FUNCTION__,
1274                     [&model12, &preference, &modelCache, &dataCache, &token,
1275                      &callback](const sp<V1_2::IDevice>& device) {
1276                         return device->prepareModel_1_2(model12, preference, modelCache, dataCache,
1277                                                         token, callback);
1278                     },
1279                     callback);
1280             if (ret.isDeadObject()) {
1281                 LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
1282                 return kDeadObject;
1283             }
1284             if (!ret.isOk()) {
1285                 LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
1286                 return prepareModelFailure();
1287             }
1288             const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
1289             if (status != V1_0::ErrorStatus::NONE) {
1290                 LOG(ERROR) << "prepareModel_1_2 returned " << toString(status);
1291                 return prepareModelFailure(convertToV1_3(status));
1292             }
1293             return prepareModelResult(*callback, "prepareModel_1_2", kServiceName);
1294         }
1295 
1296         LOG(ERROR) << "Could not handle prepareModel_1_2!";
1297         return prepareModelFailure();
1298     }
1299 
1300     // If 1.1 device, try preparing model (requires conversion)
1301     if (getDevice<V1_1::IDevice>() != nullptr) {
1302         bool compliant = false;
1303         V1_1::Model model11;
1304         {
1305             // Attribute time spent in model inspection and conversion to
1306             // Runtime, as the time may be substantial (0.03ms for mobilenet,
1307             // but could be larger for other models).
1308             NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
1309                                   "VersionedIDevice::prepareModel_1_1");
1310             compliant = compliantWithV1_1(model);
1311             if (compliant) {
1312                 model11 = convertToV1_1(model);  // copy is elided
1313             }
1314         }
1315         if (compliant) {
1316             const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_1::IDevice>(
1317                     __FUNCTION__,
1318                     [&model11, &preference, &callback](const sp<V1_1::IDevice>& device) {
1319                         return device->prepareModel_1_1(model11, preference, callback);
1320                     },
1321                     callback);
1322             if (ret.isDeadObject()) {
1323                 LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
1324                 return kDeadObject;
1325             }
1326             if (!ret.isOk()) {
1327                 LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
1328                 return prepareModelFailure();
1329             }
1330             const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
1331             if (status != V1_0::ErrorStatus::NONE) {
1332                 LOG(ERROR) << "prepareModel_1_1 returned " << toString(status);
1333                 return prepareModelFailure(convertToV1_3(status));
1334             }
1335             return prepareModelResult(*callback, "prepareModel_1_1", kServiceName);
1336         }
1337 
1338         LOG(ERROR) << "Could not handle prepareModel_1_1!";
1339         return prepareModelFailure();
1340     }
1341 
1342     // If 1.0 device, try preparing model (requires conversion)
1343     if (getDevice<V1_0::IDevice>() != nullptr) {
1344         bool compliant = false;
1345         V1_0::Model model10;
1346         {
1347             // Attribute time spent in model inspection and conversion to
1348             // Runtime, as the time may be substantial (0.03ms for mobilenet,
1349             // but could be larger for other models).
1350             NNTRACE_FULL_SUBTRACT(NNTRACE_LAYER_RUNTIME, NNTRACE_PHASE_COMPILATION,
1351                                   "VersionedIDevice::prepareModel");
1352             compliant = compliantWithV1_0(model);
1353             if (compliant) {
1354                 model10 = convertToV1_0(model);  // copy is elided
1355             }
1356         }
1357         if (compliant) {
1358             const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_0::IDevice>(
1359                     __FUNCTION__,
1360                     [&model10, &callback](const sp<V1_0::IDevice>& device) {
1361                         return device->prepareModel(model10, callback);
1362                     },
1363                     callback);
1364             if (ret.isDeadObject()) {
1365                 LOG(ERROR) << "prepareModel failure: " << ret.description();
1366                 return kDeadObject;
1367             }
1368             if (!ret.isOk()) {
1369                 LOG(ERROR) << "prepareModel failure: " << ret.description();
1370                 return prepareModelFailure();
1371             }
1372             const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
1373             if (status != V1_0::ErrorStatus::NONE) {
1374                 LOG(ERROR) << "prepareModel returned " << toString(status);
1375                 return prepareModelFailure(convertToV1_3(status));
1376             }
1377             return prepareModelResult(*callback, "prepareModel", kServiceName);
1378         }
1379 
1380         LOG(ERROR) << "Could not handle prepareModel!";
1381         return prepareModelFailure();
1382     }
1383 
1384     // Return error because there is no valid device
1385     LOG(ERROR) << "prepareModel called with no device";
1386     return prepareModelFailure();
1387 }
1388 
1389 std::pair<int, std::shared_ptr<VersionedIPreparedModel>>
prepareModelFromCacheInternal(const std::optional<Deadline> & deadline,const std::string & cacheDir,const CacheToken & token) const1390 VersionedIDevice::prepareModelFromCacheInternal(const std::optional<Deadline>& deadline,
1391                                                 const std::string& cacheDir,
1392                                                 const CacheToken& token) const {
1393     // Note that some work within VersionedIDevice will be subtracted from the IPC layer
1394     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "prepareModelFromCache");
1395     VLOG(COMPILATION) << "prepareModelFromCache";
1396     const std::pair<int, std::shared_ptr<VersionedIPreparedModel>> kDeadObject = {
1397             ANEURALNETWORKS_DEAD_OBJECT, nullptr};
1398 
1399     // Get cache files if they exist, otherwise return from the function early.
1400     hidl_vec<hidl_handle> modelCache, dataCache;
1401     if (!getCacheHandles(cacheDir, token, kNumberOfCacheFilesNeeded,
1402                          /*createIfNotExist=*/false, &modelCache, &dataCache)) {
1403         return prepareModelFailure();
1404     }
1405 
1406     // version 1.3+ HAL
1407     if (getDevice<V1_3::IDevice>() != nullptr) {
1408         const auto otp = makeTimePoint(deadline);
1409         const sp<PreparedModelCallback> callback = new PreparedModelCallback();
1410         const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_3::IDevice>(
1411                 __FUNCTION__,
1412                 [&otp, &modelCache, &dataCache, &token,
1413                  &callback](const sp<V1_3::IDevice>& device) {
1414                     return device->prepareModelFromCache_1_3(otp, modelCache, dataCache, token,
1415                                                              callback);
1416                 },
1417                 callback);
1418         if (ret.isDeadObject()) {
1419             LOG(ERROR) << "prepareModelFromCache_1_3 failure: " << ret.description();
1420             return kDeadObject;
1421         }
1422         if (!ret.isOk()) {
1423             LOG(ERROR) << "prepareModelFromCache_1_3 failure: " << ret.description();
1424             return prepareModelFailure();
1425         }
1426         if (ret != ErrorStatus::NONE) {
1427             LOG(ERROR) << "prepareModelFromCache_1_3 returned "
1428                        << toString(static_cast<ErrorStatus>(ret));
1429             return prepareModelFailure(ret);
1430         }
1431         return prepareModelResult(*callback, "prepareModelFromCache_1_3", kServiceName);
1432     }
1433 
1434     // version 1.2 HAL
1435     if (getDevice<V1_2::IDevice>() != nullptr) {
1436         const sp<PreparedModelCallback> callback = new PreparedModelCallback();
1437         const Return<V1_0::ErrorStatus> ret = recoverable<V1_0::ErrorStatus, V1_2::IDevice>(
1438                 __FUNCTION__,
1439                 [&modelCache, &dataCache, &token, &callback](const sp<V1_2::IDevice>& device) {
1440                     return device->prepareModelFromCache(modelCache, dataCache, token, callback);
1441                 },
1442                 callback);
1443         if (ret.isDeadObject()) {
1444             LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
1445             return kDeadObject;
1446         }
1447         if (!ret.isOk()) {
1448             LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
1449             return prepareModelFailure();
1450         }
1451         const V1_0::ErrorStatus status = static_cast<V1_0::ErrorStatus>(ret);
1452         if (status != V1_0::ErrorStatus::NONE) {
1453             LOG(ERROR) << "prepareModelFromCache returned " << toString(status);
1454             return prepareModelFailure(convertToV1_3(status));
1455         }
1456         return prepareModelResult(*callback, "prepareModelFromCache", kServiceName);
1457     }
1458 
1459     // version too low
1460     if (getDevice<V1_0::IDevice>() != nullptr) {
1461         LOG(ERROR) << "prepareModelFromCache called on V1_1 or V1_0 device";
1462         return prepareModelFailure();
1463     }
1464 
1465     // No device available
1466     LOG(ERROR) << "prepareModelFromCache called with no device";
1467     return prepareModelFailure();
1468 }
1469 
prepareModel(const ModelFactory & makeModel,ExecutionPreference preference,Priority priority,const std::optional<Deadline> & deadline,const std::string & cacheDir,const std::optional<CacheToken> & maybeToken) const1470 std::pair<int, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevice::prepareModel(
1471         const ModelFactory& makeModel, ExecutionPreference preference, Priority priority,
1472         const std::optional<Deadline>& deadline, const std::string& cacheDir,
1473         const std::optional<CacheToken>& maybeToken) const {
1474     // Attempt to compile from cache if token is present.
1475     if (maybeToken.has_value()) {
1476         const auto [n, preparedModel] =
1477                 prepareModelFromCacheInternal(deadline, cacheDir, *maybeToken);
1478         if (n == ANEURALNETWORKS_NO_ERROR) {
1479             return {n, preparedModel};
1480         }
1481     }
1482 
1483     // Fallback to full compilation (possibly with token) if
1484     // prepareModelFromCache could not be used or failed.
1485     const Model model = makeModel();
1486     return prepareModelInternal(model, preference, priority, deadline, cacheDir, maybeToken);
1487 }
1488 
getFeatureLevel() const1489 int64_t VersionedIDevice::getFeatureLevel() const {
1490     constexpr int64_t kFailure = -1;
1491 
1492     if (getDevice<V1_3::IDevice>() != nullptr) {
1493         return __ANDROID_API_R__;
1494     } else if (getDevice<V1_2::IDevice>() != nullptr) {
1495         return __ANDROID_API_Q__;
1496     } else if (getDevice<V1_1::IDevice>() != nullptr) {
1497         return __ANDROID_API_P__;
1498     } else if (getDevice<V1_0::IDevice>() != nullptr) {
1499         return __ANDROID_API_O_MR1__;
1500     } else {
1501         LOG(ERROR) << "Device not available!";
1502         return kFailure;
1503     }
1504 }
1505 
getType() const1506 int32_t VersionedIDevice::getType() const {
1507     return kType;
1508 }
1509 
getVersionString() const1510 const std::string& VersionedIDevice::getVersionString() const {
1511     return kVersionString;
1512 }
1513 
getNumberOfCacheFilesNeeded() const1514 std::pair<uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFilesNeeded() const {
1515     return kNumberOfCacheFilesNeeded;
1516 }
1517 
getName() const1518 const std::string& VersionedIDevice::getName() const {
1519     return kServiceName;
1520 }
1521 
allocate(const BufferDesc & desc,const std::vector<std::shared_ptr<VersionedIPreparedModel>> & versionedPreparedModels,const hidl_vec<BufferRole> & inputRoles,const hidl_vec<BufferRole> & outputRoles) const1522 std::tuple<ErrorStatus, sp<IBuffer>, uint32_t> VersionedIDevice::allocate(
1523         const BufferDesc& desc,
1524         const std::vector<std::shared_ptr<VersionedIPreparedModel>>& versionedPreparedModels,
1525         const hidl_vec<BufferRole>& inputRoles, const hidl_vec<BufferRole>& outputRoles) const {
1526     const auto kFailure = std::make_tuple<ErrorStatus, sp<IBuffer>, uint32_t>(
1527             ErrorStatus::GENERAL_FAILURE, nullptr, 0);
1528 
1529     // version 1.3+ HAL
1530     if (getDevice<V1_3::IDevice>() != nullptr) {
1531         hidl_vec<sp<V1_3::IPreparedModel>> preparedModels(versionedPreparedModels.size());
1532         std::transform(versionedPreparedModels.begin(), versionedPreparedModels.end(),
1533                        preparedModels.begin(),
1534                        [](const auto& preparedModel) { return preparedModel->getV1_3(); });
1535 
1536         std::tuple<ErrorStatus, sp<IBuffer>, int32_t> result;
1537         const Return<void> ret = recoverable<void, V1_3::IDevice>(
1538                 __FUNCTION__, [&](const sp<V1_3::IDevice>& device) {
1539                     return device->allocate(desc, preparedModels, inputRoles, outputRoles,
1540                                             [&result](ErrorStatus error, const sp<IBuffer>& buffer,
1541                                                       uint32_t token) {
1542                                                 result = {error, buffer, token};
1543                                             });
1544                 });
1545         if (!ret.isOk()) {
1546             LOG(ERROR) << "allocate failure: " << ret.description();
1547             return kFailure;
1548         }
1549         return result;
1550     }
1551 
1552     // version too low or no device available
1553     LOG(ERROR) << "Could not handle allocate";
1554     return kFailure;
1555 }
1556 
1557 }  // namespace nn
1558 }  // namespace android
1559