1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "SampleDriver"
18 
19 #include "SampleDriver.h"
20 
21 #include <android-base/logging.h>
22 #include <android-base/properties.h>
23 #include <hidl/LegacySupport.h>
24 
25 #include <algorithm>
26 #include <chrono>
27 #include <map>
28 #include <memory>
29 #include <optional>
30 #include <set>
31 #include <thread>
32 #include <tuple>
33 #include <utility>
34 #include <vector>
35 
36 #include "BufferTracker.h"
37 #include "CpuExecutor.h"
38 #include "ExecutionBurstServer.h"
39 #include "HalInterfaces.h"
40 #include "SampleDriverUtils.h"
41 #include "Tracing.h"
42 #include "ValidateHal.h"
43 
44 namespace android {
45 namespace nn {
46 namespace sample_driver {
47 
48 namespace {
49 
50 using namespace hal;
51 
52 using time_point = std::chrono::steady_clock::time_point;
53 
// Returns the current time point on the monotonic steady clock; used to
// measure driver and device execution durations.
auto now() {
    return std::chrono::steady_clock::now();
}

// Returns the elapsed time between two steady-clock time points, in
// microseconds, as a signed integral count (end - start).
auto microsecondsDuration(decltype(now()) end, decltype(now()) start) {
    return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
}
61 
62 }  // namespace
63 
// Sentinel Timing value reported when no timing was measured (or the
// execution failed); UINT64_MAX is the HAL's "no time available" marker.
static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
65 
getCapabilities(getCapabilities_cb cb)66 Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
67     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
68                  "SampleDriver::getCapabilities");
69     return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
70         // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
71         cb(convertToV1_0(error), convertToV1_0(capabilities));
72     });
73 }
74 
getCapabilities_1_1(getCapabilities_1_1_cb cb)75 Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
76     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
77                  "SampleDriver::getCapabilities_1_1");
78     return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
79         // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
80         cb(convertToV1_0(error), convertToV1_1(capabilities));
81     });
82 }
83 
getCapabilities_1_2(getCapabilities_1_2_cb cb)84 Return<void> SampleDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
85     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
86                  "SampleDriver::getCapabilities_1_2");
87     return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
88         // TODO(dgross): Do we need to check compliantWithV1_2(capabilities)?
89         cb(convertToV1_0(error), convertToV1_2(capabilities));
90     });
91 }
92 
// Reports a fixed placeholder version string; the sample driver has no
// meaningful versioning.
Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
                 "SampleDriver::getVersionString");
    cb(V1_0::ErrorStatus::NONE, "JUST_AN_EXAMPLE");
    return Void();
}
99 
// Reports this driver's device type; the sample driver executes on the CPU.
Return<void> SampleDriver::getType(getType_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
    cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
    return Void();
}
105 
// Reports the vendor extensions supported by this driver: none.
Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
                 "SampleDriver::getSupportedExtensions");
    cb(V1_0::ErrorStatus::NONE, {/* No extensions. */});
    return Void();
}
112 
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)113 Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
114                                                   getSupportedOperations_cb cb) {
115     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
116                  "SampleDriver::getSupportedOperations");
117     if (!validateModel(model)) {
118         VLOG(DRIVER) << "getSupportedOperations";
119         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
120         return Void();
121     }
122     return getSupportedOperations_1_3(convertToV1_3(model),
123                                       [&](ErrorStatus status, const hidl_vec<bool>& supported) {
124                                           cb(convertToV1_0(status), supported);
125                                       });
126 }
127 
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb cb)128 Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
129                                                       getSupportedOperations_1_1_cb cb) {
130     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
131                  "SampleDriver::getSupportedOperations_1_1");
132     if (!validateModel(model)) {
133         VLOG(DRIVER) << "getSupportedOperations_1_1";
134         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
135         return Void();
136     }
137     return getSupportedOperations_1_3(convertToV1_3(model),
138                                       [&](ErrorStatus status, const hidl_vec<bool>& supported) {
139                                           cb(convertToV1_0(status), supported);
140                                       });
141 }
142 
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb cb)143 Return<void> SampleDriver::getSupportedOperations_1_2(const V1_2::Model& model,
144                                                       getSupportedOperations_1_2_cb cb) {
145     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
146                  "SampleDriver::getSupportedOperations_1_2");
147     if (!validateModel(model)) {
148         VLOG(DRIVER) << "getSupportedOperations_1_2";
149         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
150         return Void();
151     }
152     return getSupportedOperations_1_3(convertToV1_3(model),
153                                       [&](ErrorStatus status, const hidl_vec<bool>& supported) {
154                                           cb(convertToV1_0(status), supported);
155                                       });
156 }
157 
// Reports how many model/data cache files this driver needs for compilation
// caching; the sample driver does not support caching.
Return<void> SampleDriver::getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
                 "SampleDriver::getNumberOfCacheFilesNeeded");
    // Set both numbers to be 0 for cache not supported.
    cb(V1_0::ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
    return Void();
}
165 
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & callback)166 Return<V1_0::ErrorStatus> SampleDriver::prepareModel(
167         const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) {
168     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
169     const ErrorStatus status = prepareModelBase(
170             model, this, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, {}, callback);
171     return convertToV1_0(status);
172 }
173 
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & callback)174 Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_1(
175         const V1_1::Model& model, ExecutionPreference preference,
176         const sp<V1_0::IPreparedModelCallback>& callback) {
177     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
178     const ErrorStatus status =
179             prepareModelBase(model, this, preference, kDefaultPriority, {}, callback);
180     return convertToV1_0(status);
181 }
182 
prepareModel_1_2(const V1_2::Model & model,ExecutionPreference preference,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)183 Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_2(
184         const V1_2::Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>&,
185         const hidl_vec<hidl_handle>&, const CacheToken&,
186         const sp<V1_2::IPreparedModelCallback>& callback) {
187     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
188     const ErrorStatus status =
189             prepareModelBase(model, this, preference, kDefaultPriority, {}, callback);
190     return convertToV1_0(status);
191 }
192 
// V1_3 compilation entry point: forwards preference, priority, and deadline
// straight to prepareModelBase. Cache arguments are ignored (no caching).
Return<V1_3::ErrorStatus> SampleDriver::prepareModel_1_3(
        const V1_3::Model& model, ExecutionPreference preference, Priority priority,
        const OptionalTimePoint& deadline, const hidl_vec<hidl_handle>&,
        const hidl_vec<hidl_handle>&, const CacheToken&,
        const sp<V1_3::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_3");
    return prepareModelBase(model, this, preference, priority, deadline, callback);
}
201 
// Compilation caching is unsupported: always fails, reporting GENERAL_FAILURE
// both through the callback and as the return value (per the HAL contract).
Return<V1_0::ErrorStatus> SampleDriver::prepareModelFromCache(
        const hidl_vec<hidl_handle>&, const hidl_vec<hidl_handle>&, const CacheToken&,
        const sp<V1_2::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::prepareModelFromCache");
    notify(callback, ErrorStatus::GENERAL_FAILURE, nullptr);
    return V1_0::ErrorStatus::GENERAL_FAILURE;
}
210 
// V1_3 variant of prepareModelFromCache. Caching is unsupported, so the
// deadline is irrelevant and the call always fails with GENERAL_FAILURE.
Return<ErrorStatus> SampleDriver::prepareModelFromCache_1_3(
        const OptionalTimePoint& /*deadline*/, const hidl_vec<hidl_handle>&,
        const hidl_vec<hidl_handle>&, const CacheToken&,
        const sp<V1_3::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::prepareModelFromCache_1_3");
    notify(callback, ErrorStatus::GENERAL_FAILURE, nullptr);
    return ErrorStatus::GENERAL_FAILURE;
}
220 
// Reports device status; this in-process sample driver is always available.
Return<DeviceStatus> SampleDriver::getStatus() {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
    VLOG(DRIVER) << "getStatus()";
    return DeviceStatus::AVAILABLE;
}
226 
227 // Safely downcast an IPreparedModel object to SamplePreparedModel.
228 // This function will return nullptr if the IPreparedModel object is not originated from the sample
229 // driver process.
castToSamplePreparedModel(const sp<IPreparedModel> & preparedModel)230 static const SamplePreparedModel* castToSamplePreparedModel(
231         const sp<IPreparedModel>& preparedModel) {
232     if (preparedModel->isRemote()) {
233         return nullptr;
234     } else {
235         // This static_cast is safe because SamplePreparedModel is the only class that implements
236         // the IPreparedModel interface in the sample driver process.
237         return static_cast<const SamplePreparedModel*>(preparedModel.get());
238     }
239 }
240 
// Allocates a driver-managed device memory for the given buffer descriptor
// and roles. On success, returns (via cb) an IBuffer plus a nonzero token
// that clients place in Request::MemoryPool entries to reference the buffer.
Return<void> SampleDriver::allocate(const V1_3::BufferDesc& desc,
                                    const hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
                                    const hidl_vec<V1_3::BufferRole>& inputRoles,
                                    const hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) {
    // Token 0 is the HAL's "invalid buffer" sentinel, reported on failure.
    constexpr uint32_t kInvalidBufferToken = 0;

    VLOG(DRIVER) << "SampleDriver::allocate";
    std::set<PreparedModelRole> roles;
    V1_3::Operand operand;
    // Maps an IPreparedModel back to its Model; returns nullptr for prepared
    // models that did not originate from this driver process.
    auto getModel = [](const sp<V1_3::IPreparedModel>& preparedModel) -> const V1_3::Model* {
        const auto* samplePreparedModel = castToSamplePreparedModel(preparedModel);
        if (samplePreparedModel == nullptr) {
            LOG(ERROR) << "SampleDriver::allocate -- unknown remote IPreparedModel.";
            return nullptr;
        }
        return samplePreparedModel->getModel();
    };
    // Validates the descriptor/roles against the models and derives the
    // operand (type + dimensions) this buffer must represent.
    if (!validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles, getModel, &roles,
                            &operand)) {
        LOG(ERROR) << "SampleDriver::allocate -- validation failed.";
        cb(ErrorStatus::INVALID_ARGUMENT, nullptr, kInvalidBufferToken);
        return Void();
    }

    if (isExtensionOperandType(operand.type)) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support extension type.";
        cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return Void();
    }

    // TODO(xusongw): Support allocating buffers with unknown dimensions or rank.
    // A computed size of 0 indicates a dynamic shape, which is unsupported.
    uint32_t size = nonExtensionOperandSizeOfData(operand.type, operand.dimensions);
    VLOG(DRIVER) << "SampleDriver::allocate -- type = " << toString(operand.type)
                 << ", dimensions = " << toString(operand.dimensions) << ", size = " << size;
    if (size == 0) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support dynamic output shape.";
        cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return Void();
    }

    auto bufferWrapper = ManagedBuffer::create(size, std::move(roles), std::move(operand));
    if (bufferWrapper == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- not enough memory.";
        cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return Void();
    }

    // The tracker hands out a unique token object; the token value is freed
    // for reuse when that object (owned by the SampleBuffer) is destroyed.
    auto token = mBufferTracker->add(bufferWrapper);
    if (token == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- BufferTracker returned invalid token.";
        cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return Void();
    }

    const uint32_t tokenValue = token->get();
    sp<SampleBuffer> sampleBuffer = new SampleBuffer(std::move(bufferWrapper), std::move(token));
    VLOG(DRIVER) << "SampleDriver::allocate -- successfully allocates the requested memory";
    cb(ErrorStatus::NONE, std::move(sampleBuffer), tokenValue);
    return Void();
}
301 
run()302 int SampleDriver::run() {
303     android::hardware::configureRpcThreadpool(4, true);
304     if (registerAsService(mName) != android::OK) {
305         LOG(ERROR) << "Could not register service";
306         return 1;
307     }
308     android::hardware::joinRpcThreadpool();
309     LOG(ERROR) << "Service exited!";
310     return 1;
311 }
312 
copyRunTimePoolInfos(const RunTimePoolInfo & srcPool,const RunTimePoolInfo & dstPool)313 static void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
314     CHECK(srcPool.getBuffer() != nullptr);
315     CHECK(dstPool.getBuffer() != nullptr);
316     CHECK(srcPool.getSize() == dstPool.getSize());
317     std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
318     dstPool.flush();
319 }
320 
copyTo(const hidl_memory & dst)321 Return<ErrorStatus> SampleBuffer::copyTo(const hidl_memory& dst) {
322     const auto dstPool = RunTimePoolInfo::createFromHidlMemory(dst);
323     if (!dstPool.has_value()) {
324         LOG(ERROR) << "SampleBuffer::copyTo -- unable to map dst memory.";
325         return ErrorStatus::GENERAL_FAILURE;
326     }
327     const ErrorStatus validationStatus = kBuffer->validateCopyTo(dstPool->getSize());
328     if (validationStatus != ErrorStatus::NONE) {
329         return validationStatus;
330     }
331     const auto srcPool = kBuffer->createRunTimePoolInfo();
332     copyRunTimePoolInfos(srcPool, dstPool.value());
333     return ErrorStatus::NONE;
334 }
335 
copyFromInternal(const hidl_memory & src,const hidl_vec<uint32_t> & dimensions,const std::shared_ptr<ManagedBuffer> & bufferWrapper)336 static ErrorStatus copyFromInternal(const hidl_memory& src, const hidl_vec<uint32_t>& dimensions,
337                                     const std::shared_ptr<ManagedBuffer>& bufferWrapper) {
338     CHECK(bufferWrapper != nullptr);
339     const auto srcPool = RunTimePoolInfo::createFromHidlMemory(src);
340     if (!srcPool.has_value()) {
341         LOG(ERROR) << "SampleBuffer::copyFrom -- unable to map src memory.";
342         return ErrorStatus::GENERAL_FAILURE;
343     }
344     const ErrorStatus validationStatus =
345             bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize());
346     if (validationStatus != ErrorStatus::NONE) {
347         return validationStatus;
348     }
349     const auto dstPool = bufferWrapper->createRunTimePoolInfo();
350     copyRunTimePoolInfos(srcPool.value(), dstPool);
351     return ErrorStatus::NONE;
352 }
353 
copyFrom(const hidl_memory & src,const hidl_vec<uint32_t> & dimensions)354 Return<ErrorStatus> SampleBuffer::copyFrom(const hidl_memory& src,
355                                            const hidl_vec<uint32_t>& dimensions) {
356     const auto status = copyFromInternal(src, dimensions, kBuffer);
357     if (status == ErrorStatus::NONE) {
358         kBuffer->updateDimensions(dimensions);
359         kBuffer->setInitialized(true);
360     } else {
361         kBuffer->setInitialized(false);
362     }
363     return status;
364 }
365 
// Maps the model's constant-data memory pools into mPoolInfos; returns false
// if any pool cannot be mapped.
bool SamplePreparedModel::initialize() {
    return setRunTimePoolInfosFromHidlMemories(&mPoolInfos, mModel.pools);
}
369 
// Resolves each Request memory pool into a RunTimePoolInfo usable by
// CpuExecutor. hidl_memory pools are mapped directly; token pools are looked
// up in the driver's BufferTracker and validated against the request.
// Returns the status plus two parallel vectors indexed by pool: the pool
// infos, and the ManagedBuffer behind each token pool (nullptr for
// hidl_memory pools). On any failure both vectors are returned empty.
static std::tuple<ErrorStatus, std::vector<RunTimePoolInfo>,
                  std::vector<std::shared_ptr<ManagedBuffer>>>
createRunTimePoolInfos(const Request& request, const SampleDriver& driver,
                       const SamplePreparedModel* preparedModel) {
    std::vector<RunTimePoolInfo> requestPoolInfos;
    std::vector<std::shared_ptr<ManagedBuffer>> bufferWrappers;
    requestPoolInfos.reserve(request.pools.size());
    bufferWrappers.reserve(request.pools.size());
    for (uint32_t i = 0; i < request.pools.size(); i++) {
        auto& pool = request.pools[i];
        switch (pool.getDiscriminator()) {
            case Request::MemoryPool::hidl_discriminator::hidlMemory: {
                // Plain shared memory supplied by the client: map it.
                auto buffer = RunTimePoolInfo::createFromHidlMemory(pool.hidlMemory());
                if (!buffer.has_value()) {
                    LOG(ERROR) << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
                    return {ErrorStatus::GENERAL_FAILURE, {}, {}};
                }
                requestPoolInfos.push_back(std::move(*buffer));
                // Keep the vectors parallel: no ManagedBuffer for this pool.
                bufferWrappers.push_back(nullptr);
            } break;
            case Request::MemoryPool::hidl_discriminator::token: {
                // Driver-allocated memory referenced by token (see allocate()).
                auto bufferWrapper = driver.getBufferTracker()->get(pool.token());
                if (bufferWrapper == nullptr) {
                    return {ErrorStatus::INVALID_ARGUMENT, {}, {}};
                }
                // Checks roles/initialization of the buffer for this request.
                const auto validationStatus =
                        bufferWrapper->validateRequest(i, request, preparedModel);
                if (validationStatus != ErrorStatus::NONE) {
                    return {validationStatus, {}, {}};
                }
                requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
                bufferWrappers.push_back(std::move(bufferWrapper));
            } break;
        }
    }
    return {ErrorStatus::NONE, std::move(requestPoolInfos), std::move(bufferWrappers)};
}
407 
// Post-execution bookkeeping for outputs that live in driver-managed (token)
// memories. On successful execution, propagates the executor's output shapes
// into those buffers and marks them initialized. On OUTPUT_INSUFFICIENT_SIZE,
// converts the error to GENERAL_FAILURE when a device memory caused it (its
// pre-declared dimensions were wrong). Returns NONE otherwise.
static ErrorStatus updateDeviceMemories(
        ErrorStatus status, const Request& request,
        const std::vector<std::shared_ptr<ManagedBuffer>>& bufferWrappers,
        const hidl_vec<OutputShape>& outputShapes) {
    if (status == ErrorStatus::NONE) {
        // First pass: update dimensions on every token-backed output before
        // marking anything initialized, so a failure leaves no buffer
        // spuriously marked as initialized.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
                if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
                    return ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
        // Second pass: all dimension updates succeeded; mark initialized.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
                bufferWrappers[poolIndex]->setInitialized(true);
            }
        }
    } else if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
        // dimensions of the device memory are incorrectly specified. The driver should return
        // GENERAL_FAILURE instead in this case.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
                if (!outputShapes[i].isSufficient) {
                    LOG(ERROR) << "Invalid dimensions for output " << i
                               << ": actual shape = " << toString(outputShapes[i].dimensions);
                    return ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
    }
    return ErrorStatus::NONE;
}
447 
// Body of an asynchronous execution, run on its own thread by executeBase().
// Maps the request pools, runs the model on CpuExecutor, updates any device
// memories, and reports the result (and timing, when requested) through the
// version-appropriate execution callback.
template <typename T_IExecutionCallback>
void asyncExecute(const Request& request, MeasureTiming measure, time_point driverStart,
                  const Model& model, const SampleDriver& driver,
                  const SamplePreparedModel* preparedModel,
                  const std::vector<RunTimePoolInfo>& poolInfos,
                  const std::optional<Deadline>& deadline,
                  const OptionalTimeoutDuration& loopTimeoutDuration,
                  const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                 "SampleDriver::asyncExecute");

    // Resolve the request's memory pools (hidl_memory and token-based).
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != ErrorStatus::NONE) {
        notify(callback, poolStatus, {}, kNoTiming);
        return;
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::asyncExecute");
    CpuExecutor executor = driver.getExecutor();
    // Apply the optional WHILE-loop timeout and execution deadline.
    if (loopTimeoutDuration.getDiscriminator() !=
        OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    time_point driverEnd, deviceStart, deviceEnd;
    if (measure == MeasureTiming::YES) deviceStart = now();
    int n = executor.run(model, request, poolInfos, requestPoolInfos);
    if (measure == MeasureTiming::YES) deviceEnd = now();
    VLOG(DRIVER) << "executor.run returned " << n;
    ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
    hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();

    // Update device memory metadata.
    const ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != ErrorStatus::NONE) {
        notify(callback, updateStatus, {}, kNoTiming);
        return;
    }

    // Timing is only reported for fully successful, measured executions.
    if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
        driverEnd = now();
        Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                         .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
        notify(callback, executionStatus, outputShapes, timing);
    } else {
        notify(callback, executionStatus, outputShapes, kNoTiming);
    }
}
502 
// Common front end for the asynchronous execute() entry points. Validates
// arguments and the deadline on the caller's thread, then launches
// asyncExecute() on a detached thread. Per the HAL, a deadline that has
// already passed is reported through the callback while the call itself
// returns NONE (the launch succeeded).
template <typename T_IExecutionCallback>
ErrorStatus executeBase(const Request& request, MeasureTiming measure, const Model& model,
                        const SampleDriver& driver, const SamplePreparedModel* preparedModel,
                        const std::vector<RunTimePoolInfo>& poolInfos,
                        const OptionalTimePoint& halDeadline,
                        const OptionalTimeoutDuration& loopTimeoutDuration,
                        const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
    VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // Capture the start time first so driver timing includes validation.
    time_point driverStart;
    if (measure == MeasureTiming::YES) driverStart = now();

    if (callback.get() == nullptr) {
        LOG(ERROR) << "invalid callback passed to executeBase";
        return ErrorStatus::INVALID_ARGUMENT;
    }
    if (!validateRequest(request, model)) {
        notify(callback, ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
        return ErrorStatus::INVALID_ARGUMENT;
    }
    const auto deadline = makeDeadline(halDeadline);
    if (hasDeadlinePassed(deadline)) {
        notify(callback, ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming);
        return ErrorStatus::NONE;
    }

    // This thread is intentionally detached because the sample driver service
    // is expected to live forever.
    // NOTE(review): model, driver, and poolInfos are captured by reference
    // into the detached thread; this assumes the prepared model and driver
    // outlive the execution — confirm against the service's lifetime model.
    std::thread([&model, &driver, preparedModel, &poolInfos, request, measure, driverStart,
                 deadline, loopTimeoutDuration, callback] {
        asyncExecute(request, measure, driverStart, model, driver, preparedModel, poolInfos,
                     deadline, loopTimeoutDuration, callback);
    }).detach();

    return ErrorStatus::NONE;
}
540 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)541 Return<V1_0::ErrorStatus> SamplePreparedModel::execute(
542         const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) {
543     const ErrorStatus status = executeBase(convertToV1_3(request), MeasureTiming::NO, mModel,
544                                            *mDriver, this, mPoolInfos, {}, {}, callback);
545     return convertToV1_0(status);
546 }
547 
execute_1_2(const V1_0::Request & request,MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)548 Return<V1_0::ErrorStatus> SamplePreparedModel::execute_1_2(
549         const V1_0::Request& request, MeasureTiming measure,
550         const sp<V1_2::IExecutionCallback>& callback) {
551     const ErrorStatus status = executeBase(convertToV1_3(request), measure, mModel, *mDriver, this,
552                                            mPoolInfos, {}, {}, callback);
553     return convertToV1_0(status);
554 }
555 
// V1_3 asynchronous execution: forwards measurement flag, deadline, and loop
// timeout straight to executeBase.
Return<V1_3::ErrorStatus> SamplePreparedModel::execute_1_3(
        const V1_3::Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
        const OptionalTimeoutDuration& loopTimeoutDuration,
        const sp<V1_3::IExecutionCallback>& callback) {
    return executeBase(request, measure, mModel, *mDriver, this, mPoolInfos, deadline,
                       loopTimeoutDuration, callback);
}
563 
// Shared implementation for the synchronous execution entry points
// (executeSynchronously and executeSynchronously_1_3): validates the request,
// maps its memory pools, runs the model on the CpuExecutor, updates device
// memory metadata, and returns (status, output shapes, timing).
//
// - measure: when MeasureTiming::YES, driver/device durations are sampled via
//   steady_clock (see now()) and reported in microseconds.
// - halDeadline: optional HAL time point; execution fails with
//   MISSED_DEADLINE_PERSISTENT if it has already passed.
// - loopTimeoutDuration: optional limit forwarded to the executor's loop
//   timeout when present.
static std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> executeSynchronouslyBase(
        const Request& request, MeasureTiming measure, const Model& model,
        const SampleDriver& driver, const SamplePreparedModel* preparedModel,
        const std::vector<RunTimePoolInfo>& poolInfos, const OptionalTimePoint& halDeadline,
        const OptionalTimeoutDuration& loopTimeoutDuration) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                 "SampleDriver::executeSynchronouslyBase");
    VLOG(DRIVER) << "executeSynchronouslyBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // Sample the driver start time before validation so timeInDriver covers
    // the whole time spent in this function.
    time_point driverStart, driverEnd, deviceStart, deviceEnd;
    if (measure == MeasureTiming::YES) driverStart = now();

    if (!validateRequest(request, model)) {
        return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
    }
    // Convert the optional HAL deadline; bail out early if it already passed.
    const auto deadline = makeDeadline(halDeadline);
    if (hasDeadlinePassed(deadline)) {
        return {ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming};
    }

    // Map the request's memory pools (including driver-managed buffers) so
    // the executor can read inputs and write outputs.
    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                        "SampleDriver::executeSynchronouslyBase");
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != ErrorStatus::NONE) {
        return {poolStatus, {}, kNoTiming};
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::executeSynchronouslyBase");
    CpuExecutor executor = driver.getExecutor();
    if (loopTimeoutDuration.getDiscriminator() !=
        OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    // timeOnDevice brackets only the executor run itself.
    if (measure == MeasureTiming::YES) deviceStart = now();
    int n = executor.run(model, request, poolInfos, requestPoolInfos);
    if (measure == MeasureTiming::YES) deviceEnd = now();
    VLOG(DRIVER) << "executor.run returned " << n;
    ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
    hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();

    // Update device memory metadata.
    const ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != ErrorStatus::NONE) {
        return {updateStatus, {}, kNoTiming};
    }

    // Timing is only reported for successful, measured executions; all other
    // paths return the kNoTiming sentinel.
    if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
        driverEnd = now();
        Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                         .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "executeSynchronouslyBase timing = " << toString(timing);
        return {executionStatus, std::move(outputShapes), timing};
    }
    return {executionStatus, std::move(outputShapes), kNoTiming};
}
625 
executeSynchronously(const V1_0::Request & request,MeasureTiming measure,executeSynchronously_cb cb)626 Return<void> SamplePreparedModel::executeSynchronously(const V1_0::Request& request,
627                                                        MeasureTiming measure,
628                                                        executeSynchronously_cb cb) {
629     auto [status, outputShapes, timing] = executeSynchronouslyBase(
630             convertToV1_3(request), measure, mModel, *mDriver, this, mPoolInfos, {}, {});
631     cb(convertToV1_0(status), std::move(outputShapes), timing);
632     return Void();
633 }
634 
executeSynchronously_1_3(const V1_3::Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)635 Return<void> SamplePreparedModel::executeSynchronously_1_3(
636         const V1_3::Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
637         const OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
638     auto [status, outputShapes, timing] = executeSynchronouslyBase(
639             request, measure, mModel, *mDriver, this, mPoolInfos, deadline, loopTimeoutDuration);
640     cb(status, std::move(outputShapes), timing);
641     return Void();
642 }
643 
644 // The sample driver will finish the execution and then return.
executeFenced(const hal::Request & request,const hidl_vec<hidl_handle> & waitFor,MeasureTiming measure,const OptionalTimePoint & halDeadline,const OptionalTimeoutDuration & loopTimeoutDuration,const OptionalTimeoutDuration & duration,executeFenced_cb cb)645 Return<void> SamplePreparedModel::executeFenced(
646         const hal::Request& request, const hidl_vec<hidl_handle>& waitFor, MeasureTiming measure,
647         const OptionalTimePoint& halDeadline, const OptionalTimeoutDuration& loopTimeoutDuration,
648         const OptionalTimeoutDuration& duration, executeFenced_cb cb) {
649     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
650                  "SamplePreparedModel::executeFenced");
651     VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(toString(request)) << ")";
652 
653     time_point driverStart, driverEnd, deviceStart, deviceEnd;
654     if (measure == MeasureTiming::YES) driverStart = now();
655 
656     if (!validateRequest(request, mModel, /*allowUnspecifiedOutput=*/false)) {
657         cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
658         return Void();
659     }
660     const auto deadline = makeDeadline(halDeadline);
661     if (hasDeadlinePassed(deadline)) {
662         cb(ErrorStatus::MISSED_DEADLINE_PERSISTENT, hidl_handle(nullptr), nullptr);
663         return Void();
664     }
665 
666     // Wait for the dependent events to signal
667     for (const auto& fenceHandle : waitFor) {
668         if (!fenceHandle.getNativeHandle()) {
669             cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
670             return Void();
671         }
672         int syncFenceFd = fenceHandle.getNativeHandle()->data[0];
673         if (syncWait(syncFenceFd, -1) != FenceState::SIGNALED) {
674             LOG(ERROR) << "syncWait failed";
675             cb(ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
676             return Void();
677         }
678     }
679 
680     // Update deadline if the timeout duration is closer than the deadline.
681     auto closestDeadline = deadline;
682     if (duration.getDiscriminator() != OptionalTimeoutDuration::hidl_discriminator::none) {
683         const auto timeoutDurationDeadline = makeDeadline(duration.nanoseconds());
684         if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
685             closestDeadline = timeoutDurationDeadline;
686         }
687     }
688 
689     time_point driverStartAfterFence;
690     if (measure == MeasureTiming::YES) driverStartAfterFence = now();
691 
692     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
693                         "SamplePreparedModel::executeFenced");
694     const auto [poolStatus, requestPoolInfos, bufferWrappers] =
695             createRunTimePoolInfos(request, *mDriver, this);
696     if (poolStatus != ErrorStatus::NONE) {
697         cb(poolStatus, hidl_handle(nullptr), nullptr);
698         return Void();
699     }
700 
701     NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
702                         "SamplePreparedModel::executeFenced");
703     CpuExecutor executor = mDriver->getExecutor();
704     if (loopTimeoutDuration.getDiscriminator() !=
705         OptionalTimeoutDuration::hidl_discriminator::none) {
706         executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
707     }
708     if (closestDeadline.has_value()) {
709         executor.setDeadline(*closestDeadline);
710     }
711     if (measure == MeasureTiming::YES) deviceStart = now();
712     int n = executor.run(mModel, request, mPoolInfos, requestPoolInfos);
713     if (measure == MeasureTiming::YES) deviceEnd = now();
714     VLOG(DRIVER) << "executor.run returned " << n;
715     ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
716     if (executionStatus != ErrorStatus::NONE) {
717         cb(executionStatus, hidl_handle(nullptr), nullptr);
718         return Void();
719     }
720 
721     // Set output memories to the initialized state.
722     if (executionStatus == ErrorStatus::NONE) {
723         for (const auto& output : request.outputs) {
724             const uint32_t poolIndex = output.location.poolIndex;
725             const auto& pool = request.pools[poolIndex];
726             if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
727                 bufferWrappers[poolIndex]->setInitialized(true);
728             }
729         }
730     }
731 
732     Timing timingSinceLaunch = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
733     Timing timingAfterFence = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
734     if (measure == MeasureTiming::YES) {
735         driverEnd = now();
736         timingSinceLaunch = {
737                 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
738                 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
739         timingAfterFence = {
740                 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
741                 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStartAfterFence))};
742         VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << toString(timingSinceLaunch);
743         VLOG(DRIVER) << "executeFenced timingAfterFence = " << toString(timingAfterFence);
744     }
745     sp<SampleFencedExecutionCallback> fencedExecutionCallback =
746             new SampleFencedExecutionCallback(timingSinceLaunch, timingAfterFence, executionStatus);
747     cb(executionStatus, hidl_handle(nullptr), fencedExecutionCallback);
748     return Void();
749 }
750 
751 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
752 // the mapping until either (1) the memory is freed in the runtime, or (2) the
753 // burst object is destroyed. This allows for subsequent executions operating on
754 // pools that have been used before to reuse the mapping instead of mapping and
755 // unmapping the memory on each execution.
756 class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
757    public:
BurstExecutorWithCache(const Model & model,const SampleDriver * driver,const std::vector<RunTimePoolInfo> & poolInfos)758     BurstExecutorWithCache(const Model& model, const SampleDriver* driver,
759                            const std::vector<RunTimePoolInfo>& poolInfos)
760         : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}
761 
isCacheEntryPresent(int32_t slot) const762     bool isCacheEntryPresent(int32_t slot) const override {
763         const auto it = mMemoryCache.find(slot);
764         return (it != mMemoryCache.end()) && it->second.has_value();
765     }
766 
addCacheEntry(const hidl_memory & memory,int32_t slot)767     void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
768         mMemoryCache[slot] = RunTimePoolInfo::createFromHidlMemory(memory);
769     }
770 
removeCacheEntry(int32_t slot)771     void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
772 
execute(const V1_0::Request & request,const std::vector<int32_t> & slots,MeasureTiming measure)773     std::tuple<V1_0::ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
774             const V1_0::Request& request, const std::vector<int32_t>& slots,
775             MeasureTiming measure) override {
776         NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
777                      "BurstExecutorWithCache::execute");
778 
779         time_point driverStart, driverEnd, deviceStart, deviceEnd;
780         if (measure == MeasureTiming::YES) driverStart = now();
781 
782         // ensure all relevant pools are valid
783         if (!std::all_of(slots.begin(), slots.end(),
784                          [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
785             return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
786         }
787 
788         // finish the request object (for validation)
789         hidl_vec<Request::MemoryPool> pools(slots.size());
790         std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
791             Request::MemoryPool pool;
792             pool.hidlMemory(mMemoryCache[slot]->getHidlMemory());
793             return pool;
794         });
795         Request fullRequest = {.inputs = request.inputs, .outputs = request.outputs};
796         fullRequest.pools = std::move(pools);
797 
798         // validate request object against the model
799         if (!validateRequest(fullRequest, mModel)) {
800             return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
801         }
802 
803         // select relevant entries from cache
804         std::vector<RunTimePoolInfo> requestPoolInfos;
805         requestPoolInfos.reserve(slots.size());
806         std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
807                        [this](int32_t slot) { return *mMemoryCache[slot]; });
808 
809         // execution
810         // Configuring the loop timeout duration is not supported. This is OK
811         // because burst does not support HAL 1.3 and hence does not support
812         // WHILE loops.
813         CpuExecutor executor = mDriver->getExecutor();
814         if (measure == MeasureTiming::YES) deviceStart = now();
815         int n = executor.run(mModel, fullRequest, mModelPoolInfos, requestPoolInfos);
816         if (measure == MeasureTiming::YES) deviceEnd = now();
817         VLOG(DRIVER) << "executor.run returned " << n;
818         V1_0::ErrorStatus executionStatus = convertToV1_0(convertResultCodeToErrorStatus(n));
819         hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
820         if (measure == MeasureTiming::YES && executionStatus == V1_0::ErrorStatus::NONE) {
821             driverEnd = now();
822             Timing timing = {
823                     .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
824                     .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
825             VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
826             return std::make_tuple(executionStatus, outputShapes, timing);
827         } else {
828             return std::make_tuple(executionStatus, outputShapes, kNoTiming);
829         }
830     }
831 
832    private:
833     const Model mModel;
834     const SampleDriver* const mDriver;
835     const std::vector<RunTimePoolInfo> mModelPoolInfos;
836     std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache;  // cached requestPoolInfos
837 };
838 
// Duration the ExecutionBurstServer busy-polls the FMQ for incoming data
// before falling back to waiting on the futex. On debuggable builds the
// window is tunable through the
// "debug.nn.sample-driver-burst-polling-window" system property.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t kDefaultPollingTimeWindow = 50;
#ifdef NN_DEBUGGABLE
    constexpr int32_t kMinPollingTimeWindow = 0;
    const int32_t windowUs =
            base::GetIntProperty("debug.nn.sample-driver-burst-polling-window",
                                 kDefaultPollingTimeWindow, kMinPollingTimeWindow);
    return std::chrono::microseconds{windowUs};
#else
    return std::chrono::microseconds{kDefaultPollingTimeWindow};
#endif  // NN_DEBUGGABLE
}
854 
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)855 Return<void> SamplePreparedModel::configureExecutionBurst(
856         const sp<V1_2::IBurstCallback>& callback,
857         const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
858         const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
859         configureExecutionBurst_cb cb) {
860     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
861                  "SampleDriver::configureExecutionBurst");
862 
863     const bool preferPowerOverLatency = (kPreference == ExecutionPreference::LOW_POWER);
864     const auto pollingTimeWindow =
865             (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
866 
867     // Alternatively, the burst could be configured via:
868     // const sp<V1_2::IBurstContext> burst =
869     //         ExecutionBurstServer::create(callback, requestChannel,
870     //                                      resultChannel, this,
871     //                                      pollingTimeWindow);
872     //
873     // However, this alternative representation does not include a memory map
874     // caching optimization, and adds overhead.
875     const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
876             std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
877     const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
878             callback, requestChannel, resultChannel, executorWithCache, pollingTimeWindow);
879 
880     if (burst == nullptr) {
881         cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
882     } else {
883         cb(V1_0::ErrorStatus::NONE, burst);
884     }
885 
886     return Void();
887 }
888 
889 }  // namespace sample_driver
890 }  // namespace nn
891 }  // namespace android
892