1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "SampleDriver"
18
19 #include "SampleDriver.h"
20
21 #include <android-base/logging.h>
22 #include <android-base/properties.h>
23 #include <hidl/LegacySupport.h>
24
25 #include <algorithm>
26 #include <chrono>
27 #include <map>
28 #include <memory>
29 #include <optional>
30 #include <set>
31 #include <thread>
32 #include <tuple>
33 #include <utility>
34 #include <vector>
35
36 #include "BufferTracker.h"
37 #include "CpuExecutor.h"
38 #include "ExecutionBurstServer.h"
39 #include "HalInterfaces.h"
40 #include "SampleDriverUtils.h"
41 #include "Tracing.h"
42 #include "ValidateHal.h"
43
44 namespace android {
45 namespace nn {
46 namespace sample_driver {
47
48 namespace {
49
50 using namespace hal;
51
52 using time_point = std::chrono::steady_clock::time_point;
53
// Samples the monotonic steady clock (the clock used for all driver timing).
auto now() {
    using std::chrono::steady_clock;
    return steady_clock::now();
}
57
// Elapsed time between two steady-clock samples, in whole microseconds.
auto microsecondsDuration(std::chrono::steady_clock::time_point end,
                          std::chrono::steady_clock::time_point start) {
    const auto elapsed = end - start;
    return std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
}
61
62 } // namespace
63
64 static const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
65
getCapabilities(getCapabilities_cb cb)66 Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
67 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
68 "SampleDriver::getCapabilities");
69 return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
70 // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
71 cb(convertToV1_0(error), convertToV1_0(capabilities));
72 });
73 }
74
getCapabilities_1_1(getCapabilities_1_1_cb cb)75 Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
76 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
77 "SampleDriver::getCapabilities_1_1");
78 return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
79 // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
80 cb(convertToV1_0(error), convertToV1_1(capabilities));
81 });
82 }
83
getCapabilities_1_2(getCapabilities_1_2_cb cb)84 Return<void> SampleDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
85 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
86 "SampleDriver::getCapabilities_1_2");
87 return getCapabilities_1_3([&](ErrorStatus error, const V1_3::Capabilities& capabilities) {
88 // TODO(dgross): Do we need to check compliantWithV1_2(capabilities)?
89 cb(convertToV1_0(error), convertToV1_2(capabilities));
90 });
91 }
92
getVersionString(getVersionString_cb cb)93 Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
94 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
95 "SampleDriver::getVersionString");
96 cb(V1_0::ErrorStatus::NONE, "JUST_AN_EXAMPLE");
97 return Void();
98 }
99
getType(getType_cb cb)100 Return<void> SampleDriver::getType(getType_cb cb) {
101 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
102 cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
103 return Void();
104 }
105
getSupportedExtensions(getSupportedExtensions_cb cb)106 Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
107 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
108 "SampleDriver::getSupportedExtensions");
109 cb(V1_0::ErrorStatus::NONE, {/* No extensions. */});
110 return Void();
111 }
112
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)113 Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
114 getSupportedOperations_cb cb) {
115 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
116 "SampleDriver::getSupportedOperations");
117 if (!validateModel(model)) {
118 VLOG(DRIVER) << "getSupportedOperations";
119 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
120 return Void();
121 }
122 return getSupportedOperations_1_3(convertToV1_3(model),
123 [&](ErrorStatus status, const hidl_vec<bool>& supported) {
124 cb(convertToV1_0(status), supported);
125 });
126 }
127
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb cb)128 Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
129 getSupportedOperations_1_1_cb cb) {
130 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
131 "SampleDriver::getSupportedOperations_1_1");
132 if (!validateModel(model)) {
133 VLOG(DRIVER) << "getSupportedOperations_1_1";
134 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
135 return Void();
136 }
137 return getSupportedOperations_1_3(convertToV1_3(model),
138 [&](ErrorStatus status, const hidl_vec<bool>& supported) {
139 cb(convertToV1_0(status), supported);
140 });
141 }
142
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb cb)143 Return<void> SampleDriver::getSupportedOperations_1_2(const V1_2::Model& model,
144 getSupportedOperations_1_2_cb cb) {
145 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
146 "SampleDriver::getSupportedOperations_1_2");
147 if (!validateModel(model)) {
148 VLOG(DRIVER) << "getSupportedOperations_1_2";
149 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
150 return Void();
151 }
152 return getSupportedOperations_1_3(convertToV1_3(model),
153 [&](ErrorStatus status, const hidl_vec<bool>& supported) {
154 cb(convertToV1_0(status), supported);
155 });
156 }
157
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)158 Return<void> SampleDriver::getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) {
159 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
160 "SampleDriver::getNumberOfCacheFilesNeeded");
161 // Set both numbers to be 0 for cache not supported.
162 cb(V1_0::ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
163 return Void();
164 }
165
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & callback)166 Return<V1_0::ErrorStatus> SampleDriver::prepareModel(
167 const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) {
168 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
169 const ErrorStatus status = prepareModelBase(
170 model, this, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, {}, callback);
171 return convertToV1_0(status);
172 }
173
prepareModel_1_1(const V1_1::Model & model,ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & callback)174 Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_1(
175 const V1_1::Model& model, ExecutionPreference preference,
176 const sp<V1_0::IPreparedModelCallback>& callback) {
177 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
178 const ErrorStatus status =
179 prepareModelBase(model, this, preference, kDefaultPriority, {}, callback);
180 return convertToV1_0(status);
181 }
182
prepareModel_1_2(const V1_2::Model & model,ExecutionPreference preference,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)183 Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_2(
184 const V1_2::Model& model, ExecutionPreference preference, const hidl_vec<hidl_handle>&,
185 const hidl_vec<hidl_handle>&, const CacheToken&,
186 const sp<V1_2::IPreparedModelCallback>& callback) {
187 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
188 const ErrorStatus status =
189 prepareModelBase(model, this, preference, kDefaultPriority, {}, callback);
190 return convertToV1_0(status);
191 }
192
prepareModel_1_3(const V1_3::Model & model,ExecutionPreference preference,Priority priority,const OptionalTimePoint & deadline,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)193 Return<V1_3::ErrorStatus> SampleDriver::prepareModel_1_3(
194 const V1_3::Model& model, ExecutionPreference preference, Priority priority,
195 const OptionalTimePoint& deadline, const hidl_vec<hidl_handle>&,
196 const hidl_vec<hidl_handle>&, const CacheToken&,
197 const sp<V1_3::IPreparedModelCallback>& callback) {
198 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_3");
199 return prepareModelBase(model, this, preference, priority, deadline, callback);
200 }
201
prepareModelFromCache(const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)202 Return<V1_0::ErrorStatus> SampleDriver::prepareModelFromCache(
203 const hidl_vec<hidl_handle>&, const hidl_vec<hidl_handle>&, const CacheToken&,
204 const sp<V1_2::IPreparedModelCallback>& callback) {
205 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
206 "SampleDriver::prepareModelFromCache");
207 notify(callback, ErrorStatus::GENERAL_FAILURE, nullptr);
208 return V1_0::ErrorStatus::GENERAL_FAILURE;
209 }
210
prepareModelFromCache_1_3(const OptionalTimePoint &,const hidl_vec<hidl_handle> &,const hidl_vec<hidl_handle> &,const CacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)211 Return<ErrorStatus> SampleDriver::prepareModelFromCache_1_3(
212 const OptionalTimePoint& /*deadline*/, const hidl_vec<hidl_handle>&,
213 const hidl_vec<hidl_handle>&, const CacheToken&,
214 const sp<V1_3::IPreparedModelCallback>& callback) {
215 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
216 "SampleDriver::prepareModelFromCache_1_3");
217 notify(callback, ErrorStatus::GENERAL_FAILURE, nullptr);
218 return ErrorStatus::GENERAL_FAILURE;
219 }
220
// Reports the device's availability; the sample driver is always AVAILABLE.
Return<DeviceStatus> SampleDriver::getStatus() {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
    VLOG(DRIVER) << "getStatus()";
    return DeviceStatus::AVAILABLE;
}
226
227 // Safely downcast an IPreparedModel object to SamplePreparedModel.
228 // This function will return nullptr if the IPreparedModel object is not originated from the sample
229 // driver process.
castToSamplePreparedModel(const sp<IPreparedModel> & preparedModel)230 static const SamplePreparedModel* castToSamplePreparedModel(
231 const sp<IPreparedModel>& preparedModel) {
232 if (preparedModel->isRemote()) {
233 return nullptr;
234 } else {
235 // This static_cast is safe because SamplePreparedModel is the only class that implements
236 // the IPreparedModel interface in the sample driver process.
237 return static_cast<const SamplePreparedModel*>(preparedModel.get());
238 }
239 }
240
// Allocates a driver-managed buffer (IDevice::allocate, NNAPI 1.3) for use
// with the given prepared models in the listed input/output roles.
// On success, returns via |cb| an IBuffer plus a non-zero token the runtime
// uses to reference the buffer in later requests.
Return<void> SampleDriver::allocate(const V1_3::BufferDesc& desc,
                                    const hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
                                    const hidl_vec<V1_3::BufferRole>& inputRoles,
                                    const hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) {
    // Token 0 is reserved to mean "no buffer" on the failure paths.
    constexpr uint32_t kInvalidBufferToken = 0;

    VLOG(DRIVER) << "SampleDriver::allocate";
    std::set<PreparedModelRole> roles;
    V1_3::Operand operand;
    // The validator needs each prepared model's underlying Model; only models
    // prepared in this process (SamplePreparedModel) can supply one.
    auto getModel = [](const sp<V1_3::IPreparedModel>& preparedModel) -> const V1_3::Model* {
        const auto* samplePreparedModel = castToSamplePreparedModel(preparedModel);
        if (samplePreparedModel == nullptr) {
            LOG(ERROR) << "SampleDriver::allocate -- unknown remote IPreparedModel.";
            return nullptr;
        }
        return samplePreparedModel->getModel();
    };
    // Validation also derives |roles| and the |operand| (type and dimensions)
    // that this buffer will hold.
    if (!validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles, getModel, &roles,
                            &operand)) {
        LOG(ERROR) << "SampleDriver::allocate -- validation failed.";
        cb(ErrorStatus::INVALID_ARGUMENT, nullptr, kInvalidBufferToken);
        return Void();
    }

    if (isExtensionOperandType(operand.type)) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support extension type.";
        cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return Void();
    }

    // TODO(xusongw): Support allocating buffers with unknown dimensions or rank.
    // A computed size of 0 indicates a dynamic shape, which is rejected below.
    uint32_t size = nonExtensionOperandSizeOfData(operand.type, operand.dimensions);
    VLOG(DRIVER) << "SampleDriver::allocate -- type = " << toString(operand.type)
                 << ", dimensions = " << toString(operand.dimensions) << ", size = " << size;
    if (size == 0) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support dynamic output shape.";
        cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return Void();
    }

    auto bufferWrapper = ManagedBuffer::create(size, std::move(roles), std::move(operand));
    if (bufferWrapper == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- not enough memory.";
        cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return Void();
    }

    // Register the buffer with the tracker to obtain the token handed back to
    // the runtime.
    auto token = mBufferTracker->add(bufferWrapper);
    if (token == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- BufferTracker returned invalid token.";
        cb(ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return Void();
    }

    // Read the token value before std::move(token) consumes the handle below.
    const uint32_t tokenValue = token->get();
    sp<SampleBuffer> sampleBuffer = new SampleBuffer(std::move(bufferWrapper), std::move(token));
    VLOG(DRIVER) << "SampleDriver::allocate -- successfully allocates the requested memory";
    cb(ErrorStatus::NONE, std::move(sampleBuffer), tokenValue);
    return Void();
}
301
run()302 int SampleDriver::run() {
303 android::hardware::configureRpcThreadpool(4, true);
304 if (registerAsService(mName) != android::OK) {
305 LOG(ERROR) << "Could not register service";
306 return 1;
307 }
308 android::hardware::joinRpcThreadpool();
309 LOG(ERROR) << "Service exited!";
310 return 1;
311 }
312
copyRunTimePoolInfos(const RunTimePoolInfo & srcPool,const RunTimePoolInfo & dstPool)313 static void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
314 CHECK(srcPool.getBuffer() != nullptr);
315 CHECK(dstPool.getBuffer() != nullptr);
316 CHECK(srcPool.getSize() == dstPool.getSize());
317 std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
318 dstPool.flush();
319 }
320
copyTo(const hidl_memory & dst)321 Return<ErrorStatus> SampleBuffer::copyTo(const hidl_memory& dst) {
322 const auto dstPool = RunTimePoolInfo::createFromHidlMemory(dst);
323 if (!dstPool.has_value()) {
324 LOG(ERROR) << "SampleBuffer::copyTo -- unable to map dst memory.";
325 return ErrorStatus::GENERAL_FAILURE;
326 }
327 const ErrorStatus validationStatus = kBuffer->validateCopyTo(dstPool->getSize());
328 if (validationStatus != ErrorStatus::NONE) {
329 return validationStatus;
330 }
331 const auto srcPool = kBuffer->createRunTimePoolInfo();
332 copyRunTimePoolInfos(srcPool, dstPool.value());
333 return ErrorStatus::NONE;
334 }
335
copyFromInternal(const hidl_memory & src,const hidl_vec<uint32_t> & dimensions,const std::shared_ptr<ManagedBuffer> & bufferWrapper)336 static ErrorStatus copyFromInternal(const hidl_memory& src, const hidl_vec<uint32_t>& dimensions,
337 const std::shared_ptr<ManagedBuffer>& bufferWrapper) {
338 CHECK(bufferWrapper != nullptr);
339 const auto srcPool = RunTimePoolInfo::createFromHidlMemory(src);
340 if (!srcPool.has_value()) {
341 LOG(ERROR) << "SampleBuffer::copyFrom -- unable to map src memory.";
342 return ErrorStatus::GENERAL_FAILURE;
343 }
344 const ErrorStatus validationStatus =
345 bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize());
346 if (validationStatus != ErrorStatus::NONE) {
347 return validationStatus;
348 }
349 const auto dstPool = bufferWrapper->createRunTimePoolInfo();
350 copyRunTimePoolInfos(srcPool.value(), dstPool);
351 return ErrorStatus::NONE;
352 }
353
copyFrom(const hidl_memory & src,const hidl_vec<uint32_t> & dimensions)354 Return<ErrorStatus> SampleBuffer::copyFrom(const hidl_memory& src,
355 const hidl_vec<uint32_t>& dimensions) {
356 const auto status = copyFromInternal(src, dimensions, kBuffer);
357 if (status == ErrorStatus::NONE) {
358 kBuffer->updateDimensions(dimensions);
359 kBuffer->setInitialized(true);
360 } else {
361 kBuffer->setInitialized(false);
362 }
363 return status;
364 }
365
initialize()366 bool SamplePreparedModel::initialize() {
367 return setRunTimePoolInfosFromHidlMemories(&mPoolInfos, mModel.pools);
368 }
369
// Maps every memory pool in |request| into a RunTimePoolInfo for CpuExecutor.
//
// Plain hidlMemory pools are mapped directly. Token pools refer to driver
// allocated buffers: each is looked up in the driver's BufferTracker and
// validated against the request (for this prepared model) before use.
//
// Returns {status, poolInfos, bufferWrappers}. bufferWrappers runs parallel
// to poolInfos, holding the ManagedBuffer for token-backed pools and nullptr
// for plain memory pools. On any failure both vectors are empty.
static std::tuple<ErrorStatus, std::vector<RunTimePoolInfo>,
                  std::vector<std::shared_ptr<ManagedBuffer>>>
createRunTimePoolInfos(const Request& request, const SampleDriver& driver,
                       const SamplePreparedModel* preparedModel) {
    std::vector<RunTimePoolInfo> requestPoolInfos;
    std::vector<std::shared_ptr<ManagedBuffer>> bufferWrappers;
    requestPoolInfos.reserve(request.pools.size());
    bufferWrappers.reserve(request.pools.size());
    for (uint32_t i = 0; i < request.pools.size(); i++) {
        auto& pool = request.pools[i];
        switch (pool.getDiscriminator()) {
            case Request::MemoryPool::hidl_discriminator::hidlMemory: {
                auto buffer = RunTimePoolInfo::createFromHidlMemory(pool.hidlMemory());
                if (!buffer.has_value()) {
                    LOG(ERROR) << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
                    return {ErrorStatus::GENERAL_FAILURE, {}, {}};
                }
                requestPoolInfos.push_back(std::move(*buffer));
                // No ManagedBuffer backs a plain memory pool.
                bufferWrappers.push_back(nullptr);
            } break;
            case Request::MemoryPool::hidl_discriminator::token: {
                auto bufferWrapper = driver.getBufferTracker()->get(pool.token());
                if (bufferWrapper == nullptr) {
                    return {ErrorStatus::INVALID_ARGUMENT, {}, {}};
                }
                const auto validationStatus =
                        bufferWrapper->validateRequest(i, request, preparedModel);
                if (validationStatus != ErrorStatus::NONE) {
                    return {validationStatus, {}, {}};
                }
                requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
                bufferWrappers.push_back(std::move(bufferWrapper));
            } break;
        }
    }
    return {ErrorStatus::NONE, std::move(requestPoolInfos), std::move(bufferWrappers)};
}
407
updateDeviceMemories(ErrorStatus status,const Request & request,const std::vector<std::shared_ptr<ManagedBuffer>> & bufferWrappers,const hidl_vec<OutputShape> & outputShapes)408 static ErrorStatus updateDeviceMemories(
409 ErrorStatus status, const Request& request,
410 const std::vector<std::shared_ptr<ManagedBuffer>>& bufferWrappers,
411 const hidl_vec<OutputShape>& outputShapes) {
412 if (status == ErrorStatus::NONE) {
413 for (uint32_t i = 0; i < request.outputs.size(); i++) {
414 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
415 const auto& pool = request.pools[poolIndex];
416 if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
417 if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
418 return ErrorStatus::GENERAL_FAILURE;
419 }
420 }
421 }
422 for (uint32_t i = 0; i < request.outputs.size(); i++) {
423 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
424 const auto& pool = request.pools[poolIndex];
425 if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
426 bufferWrappers[poolIndex]->setInitialized(true);
427 }
428 }
429 } else if (status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
430 // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
431 // dimensions of the device memory are incorrectly specified. The driver should return
432 // GENERAL_FAILURE instead in this case.
433 for (uint32_t i = 0; i < request.outputs.size(); i++) {
434 const uint32_t poolIndex = request.outputs[i].location.poolIndex;
435 const auto& pool = request.pools[poolIndex];
436 if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
437 if (!outputShapes[i].isSufficient) {
438 LOG(ERROR) << "Invalid dimensions for output " << i
439 << ": actual shape = " << toString(outputShapes[i].dimensions);
440 return ErrorStatus::GENERAL_FAILURE;
441 }
442 }
443 }
444 }
445 return ErrorStatus::NONE;
446 }
447
// Performs one execution on the calling thread (the detached thread spawned
// by executeBase) and reports the result through |callback|.
//
// |driverStart| is the time executeBase() began; it anchors timeInDriver when
// |measure| is MeasureTiming::YES.
template <typename T_IExecutionCallback>
void asyncExecute(const Request& request, MeasureTiming measure, time_point driverStart,
                  const Model& model, const SampleDriver& driver,
                  const SamplePreparedModel* preparedModel,
                  const std::vector<RunTimePoolInfo>& poolInfos,
                  const std::optional<Deadline>& deadline,
                  const OptionalTimeoutDuration& loopTimeoutDuration,
                  const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                 "SampleDriver::asyncExecute");

    // Map the request's memory pools, including token-backed device buffers.
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != ErrorStatus::NONE) {
        notify(callback, poolStatus, {}, kNoTiming);
        return;
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::asyncExecute");
    // Configure the executor with the optional loop timeout and deadline.
    CpuExecutor executor = driver.getExecutor();
    if (loopTimeoutDuration.getDiscriminator() !=
        OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    // Bracket only executor.run() with the device timestamps.
    time_point driverEnd, deviceStart, deviceEnd;
    if (measure == MeasureTiming::YES) deviceStart = now();
    int n = executor.run(model, request, poolInfos, requestPoolInfos);
    if (measure == MeasureTiming::YES) deviceEnd = now();
    VLOG(DRIVER) << "executor.run returned " << n;
    ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
    hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();

    // Update device memory metadata.
    const ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != ErrorStatus::NONE) {
        notify(callback, updateStatus, {}, kNoTiming);
        return;
    }

    // Timing is only reported for a fully successful, measured execution.
    if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
        driverEnd = now();
        Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                         .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
        notify(callback, executionStatus, outputShapes, timing);
    } else {
        notify(callback, executionStatus, outputShapes, kNoTiming);
    }
}
502
// Validates an execution request and launches it asynchronously.
//
// The return value is only the synchronous launch status; execution results —
// including an already-passed deadline, which is reported through |callback|
// as MISSED_DEADLINE_PERSISTENT while this function returns NONE — are
// delivered via the callback.
template <typename T_IExecutionCallback>
ErrorStatus executeBase(const Request& request, MeasureTiming measure, const Model& model,
                        const SampleDriver& driver, const SamplePreparedModel* preparedModel,
                        const std::vector<RunTimePoolInfo>& poolInfos,
                        const OptionalTimePoint& halDeadline,
                        const OptionalTimeoutDuration& loopTimeoutDuration,
                        const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
    VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // Sample driverStart as early as possible so timeInDriver also covers
    // validation.
    time_point driverStart;
    if (measure == MeasureTiming::YES) driverStart = now();

    if (callback.get() == nullptr) {
        LOG(ERROR) << "invalid callback passed to executeBase";
        return ErrorStatus::INVALID_ARGUMENT;
    }
    if (!validateRequest(request, model)) {
        notify(callback, ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
        return ErrorStatus::INVALID_ARGUMENT;
    }
    const auto deadline = makeDeadline(halDeadline);
    if (hasDeadlinePassed(deadline)) {
        notify(callback, ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming);
        return ErrorStatus::NONE;
    }

    // This thread is intentionally detached because the sample driver service
    // is expected to live forever.
    // NOTE(review): model, driver, and poolInfos are captured by reference;
    // this assumes they (i.e. the driver and prepared model) outlive the
    // detached execution — confirm that lifetime guarantee.
    std::thread([&model, &driver, preparedModel, &poolInfos, request, measure, driverStart,
                 deadline, loopTimeoutDuration, callback] {
        asyncExecute(request, measure, driverStart, model, driver, preparedModel, poolInfos,
                     deadline, loopTimeoutDuration, callback);
    }).detach();

    return ErrorStatus::NONE;
}
540
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)541 Return<V1_0::ErrorStatus> SamplePreparedModel::execute(
542 const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) {
543 const ErrorStatus status = executeBase(convertToV1_3(request), MeasureTiming::NO, mModel,
544 *mDriver, this, mPoolInfos, {}, {}, callback);
545 return convertToV1_0(status);
546 }
547
execute_1_2(const V1_0::Request & request,MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)548 Return<V1_0::ErrorStatus> SamplePreparedModel::execute_1_2(
549 const V1_0::Request& request, MeasureTiming measure,
550 const sp<V1_2::IExecutionCallback>& callback) {
551 const ErrorStatus status = executeBase(convertToV1_3(request), measure, mModel, *mDriver, this,
552 mPoolInfos, {}, {}, callback);
553 return convertToV1_0(status);
554 }
555
execute_1_3(const V1_3::Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)556 Return<V1_3::ErrorStatus> SamplePreparedModel::execute_1_3(
557 const V1_3::Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
558 const OptionalTimeoutDuration& loopTimeoutDuration,
559 const sp<V1_3::IExecutionCallback>& callback) {
560 return executeBase(request, measure, mModel, *mDriver, this, mPoolInfos, deadline,
561 loopTimeoutDuration, callback);
562 }
563
executeSynchronouslyBase(const Request & request,MeasureTiming measure,const Model & model,const SampleDriver & driver,const SamplePreparedModel * preparedModel,const std::vector<RunTimePoolInfo> & poolInfos,const OptionalTimePoint & halDeadline,const OptionalTimeoutDuration & loopTimeoutDuration)564 static std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> executeSynchronouslyBase(
565 const Request& request, MeasureTiming measure, const Model& model,
566 const SampleDriver& driver, const SamplePreparedModel* preparedModel,
567 const std::vector<RunTimePoolInfo>& poolInfos, const OptionalTimePoint& halDeadline,
568 const OptionalTimeoutDuration& loopTimeoutDuration) {
569 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
570 "SampleDriver::executeSynchronouslyBase");
571 VLOG(DRIVER) << "executeSynchronouslyBase(" << SHOW_IF_DEBUG(toString(request)) << ")";
572
573 time_point driverStart, driverEnd, deviceStart, deviceEnd;
574 if (measure == MeasureTiming::YES) driverStart = now();
575
576 if (!validateRequest(request, model)) {
577 return {ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
578 }
579 const auto deadline = makeDeadline(halDeadline);
580 if (hasDeadlinePassed(deadline)) {
581 return {ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming};
582 }
583
584 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
585 "SampleDriver::executeSynchronouslyBase");
586 const auto [poolStatus, requestPoolInfos, bufferWrappers] =
587 createRunTimePoolInfos(request, driver, preparedModel);
588 if (poolStatus != ErrorStatus::NONE) {
589 return {poolStatus, {}, kNoTiming};
590 }
591
592 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
593 "SampleDriver::executeSynchronouslyBase");
594 CpuExecutor executor = driver.getExecutor();
595 if (loopTimeoutDuration.getDiscriminator() !=
596 OptionalTimeoutDuration::hidl_discriminator::none) {
597 executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
598 }
599 if (deadline.has_value()) {
600 executor.setDeadline(*deadline);
601 }
602 if (measure == MeasureTiming::YES) deviceStart = now();
603 int n = executor.run(model, request, poolInfos, requestPoolInfos);
604 if (measure == MeasureTiming::YES) deviceEnd = now();
605 VLOG(DRIVER) << "executor.run returned " << n;
606 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
607 hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
608
609 // Update device memory metadata.
610 const ErrorStatus updateStatus =
611 updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
612 if (updateStatus != ErrorStatus::NONE) {
613 return {updateStatus, {}, kNoTiming};
614 }
615
616 if (measure == MeasureTiming::YES && executionStatus == ErrorStatus::NONE) {
617 driverEnd = now();
618 Timing timing = {.timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
619 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
620 VLOG(DRIVER) << "executeSynchronouslyBase timing = " << toString(timing);
621 return {executionStatus, std::move(outputShapes), timing};
622 }
623 return {executionStatus, std::move(outputShapes), kNoTiming};
624 }
625
executeSynchronously(const V1_0::Request & request,MeasureTiming measure,executeSynchronously_cb cb)626 Return<void> SamplePreparedModel::executeSynchronously(const V1_0::Request& request,
627 MeasureTiming measure,
628 executeSynchronously_cb cb) {
629 auto [status, outputShapes, timing] = executeSynchronouslyBase(
630 convertToV1_3(request), measure, mModel, *mDriver, this, mPoolInfos, {}, {});
631 cb(convertToV1_0(status), std::move(outputShapes), timing);
632 return Void();
633 }
634
executeSynchronously_1_3(const V1_3::Request & request,MeasureTiming measure,const OptionalTimePoint & deadline,const OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)635 Return<void> SamplePreparedModel::executeSynchronously_1_3(
636 const V1_3::Request& request, MeasureTiming measure, const OptionalTimePoint& deadline,
637 const OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
638 auto [status, outputShapes, timing] = executeSynchronouslyBase(
639 request, measure, mModel, *mDriver, this, mPoolInfos, deadline, loopTimeoutDuration);
640 cb(status, std::move(outputShapes), timing);
641 return Void();
642 }
643
644 // The sample driver will finish the execution and then return.
executeFenced(const hal::Request & request,const hidl_vec<hidl_handle> & waitFor,MeasureTiming measure,const OptionalTimePoint & halDeadline,const OptionalTimeoutDuration & loopTimeoutDuration,const OptionalTimeoutDuration & duration,executeFenced_cb cb)645 Return<void> SamplePreparedModel::executeFenced(
646 const hal::Request& request, const hidl_vec<hidl_handle>& waitFor, MeasureTiming measure,
647 const OptionalTimePoint& halDeadline, const OptionalTimeoutDuration& loopTimeoutDuration,
648 const OptionalTimeoutDuration& duration, executeFenced_cb cb) {
649 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
650 "SamplePreparedModel::executeFenced");
651 VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(toString(request)) << ")";
652
653 time_point driverStart, driverEnd, deviceStart, deviceEnd;
654 if (measure == MeasureTiming::YES) driverStart = now();
655
656 if (!validateRequest(request, mModel, /*allowUnspecifiedOutput=*/false)) {
657 cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
658 return Void();
659 }
660 const auto deadline = makeDeadline(halDeadline);
661 if (hasDeadlinePassed(deadline)) {
662 cb(ErrorStatus::MISSED_DEADLINE_PERSISTENT, hidl_handle(nullptr), nullptr);
663 return Void();
664 }
665
666 // Wait for the dependent events to signal
667 for (const auto& fenceHandle : waitFor) {
668 if (!fenceHandle.getNativeHandle()) {
669 cb(ErrorStatus::INVALID_ARGUMENT, hidl_handle(nullptr), nullptr);
670 return Void();
671 }
672 int syncFenceFd = fenceHandle.getNativeHandle()->data[0];
673 if (syncWait(syncFenceFd, -1) != FenceState::SIGNALED) {
674 LOG(ERROR) << "syncWait failed";
675 cb(ErrorStatus::GENERAL_FAILURE, hidl_handle(nullptr), nullptr);
676 return Void();
677 }
678 }
679
680 // Update deadline if the timeout duration is closer than the deadline.
681 auto closestDeadline = deadline;
682 if (duration.getDiscriminator() != OptionalTimeoutDuration::hidl_discriminator::none) {
683 const auto timeoutDurationDeadline = makeDeadline(duration.nanoseconds());
684 if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
685 closestDeadline = timeoutDurationDeadline;
686 }
687 }
688
689 time_point driverStartAfterFence;
690 if (measure == MeasureTiming::YES) driverStartAfterFence = now();
691
692 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
693 "SamplePreparedModel::executeFenced");
694 const auto [poolStatus, requestPoolInfos, bufferWrappers] =
695 createRunTimePoolInfos(request, *mDriver, this);
696 if (poolStatus != ErrorStatus::NONE) {
697 cb(poolStatus, hidl_handle(nullptr), nullptr);
698 return Void();
699 }
700
701 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
702 "SamplePreparedModel::executeFenced");
703 CpuExecutor executor = mDriver->getExecutor();
704 if (loopTimeoutDuration.getDiscriminator() !=
705 OptionalTimeoutDuration::hidl_discriminator::none) {
706 executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
707 }
708 if (closestDeadline.has_value()) {
709 executor.setDeadline(*closestDeadline);
710 }
711 if (measure == MeasureTiming::YES) deviceStart = now();
712 int n = executor.run(mModel, request, mPoolInfos, requestPoolInfos);
713 if (measure == MeasureTiming::YES) deviceEnd = now();
714 VLOG(DRIVER) << "executor.run returned " << n;
715 ErrorStatus executionStatus = convertResultCodeToErrorStatus(n);
716 if (executionStatus != ErrorStatus::NONE) {
717 cb(executionStatus, hidl_handle(nullptr), nullptr);
718 return Void();
719 }
720
721 // Set output memories to the initialized state.
722 if (executionStatus == ErrorStatus::NONE) {
723 for (const auto& output : request.outputs) {
724 const uint32_t poolIndex = output.location.poolIndex;
725 const auto& pool = request.pools[poolIndex];
726 if (pool.getDiscriminator() == Request::MemoryPool::hidl_discriminator::token) {
727 bufferWrappers[poolIndex]->setInitialized(true);
728 }
729 }
730 }
731
732 Timing timingSinceLaunch = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
733 Timing timingAfterFence = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
734 if (measure == MeasureTiming::YES) {
735 driverEnd = now();
736 timingSinceLaunch = {
737 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
738 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
739 timingAfterFence = {
740 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
741 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStartAfterFence))};
742 VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << toString(timingSinceLaunch);
743 VLOG(DRIVER) << "executeFenced timingAfterFence = " << toString(timingAfterFence);
744 }
745 sp<SampleFencedExecutionCallback> fencedExecutionCallback =
746 new SampleFencedExecutionCallback(timingSinceLaunch, timingAfterFence, executionStatus);
747 cb(executionStatus, hidl_handle(nullptr), fencedExecutionCallback);
748 return Void();
749 }
750
751 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
752 // the mapping until either (1) the memory is freed in the runtime, or (2) the
753 // burst object is destroyed. This allows for subsequent executions operating on
754 // pools that have been used before to reuse the mapping instead of mapping and
755 // unmapping the memory on each execution.
756 class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
757 public:
BurstExecutorWithCache(const Model & model,const SampleDriver * driver,const std::vector<RunTimePoolInfo> & poolInfos)758 BurstExecutorWithCache(const Model& model, const SampleDriver* driver,
759 const std::vector<RunTimePoolInfo>& poolInfos)
760 : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}
761
isCacheEntryPresent(int32_t slot) const762 bool isCacheEntryPresent(int32_t slot) const override {
763 const auto it = mMemoryCache.find(slot);
764 return (it != mMemoryCache.end()) && it->second.has_value();
765 }
766
addCacheEntry(const hidl_memory & memory,int32_t slot)767 void addCacheEntry(const hidl_memory& memory, int32_t slot) override {
768 mMemoryCache[slot] = RunTimePoolInfo::createFromHidlMemory(memory);
769 }
770
removeCacheEntry(int32_t slot)771 void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
772
execute(const V1_0::Request & request,const std::vector<int32_t> & slots,MeasureTiming measure)773 std::tuple<V1_0::ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
774 const V1_0::Request& request, const std::vector<int32_t>& slots,
775 MeasureTiming measure) override {
776 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
777 "BurstExecutorWithCache::execute");
778
779 time_point driverStart, driverEnd, deviceStart, deviceEnd;
780 if (measure == MeasureTiming::YES) driverStart = now();
781
782 // ensure all relevant pools are valid
783 if (!std::all_of(slots.begin(), slots.end(),
784 [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
785 return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
786 }
787
788 // finish the request object (for validation)
789 hidl_vec<Request::MemoryPool> pools(slots.size());
790 std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
791 Request::MemoryPool pool;
792 pool.hidlMemory(mMemoryCache[slot]->getHidlMemory());
793 return pool;
794 });
795 Request fullRequest = {.inputs = request.inputs, .outputs = request.outputs};
796 fullRequest.pools = std::move(pools);
797
798 // validate request object against the model
799 if (!validateRequest(fullRequest, mModel)) {
800 return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
801 }
802
803 // select relevant entries from cache
804 std::vector<RunTimePoolInfo> requestPoolInfos;
805 requestPoolInfos.reserve(slots.size());
806 std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
807 [this](int32_t slot) { return *mMemoryCache[slot]; });
808
809 // execution
810 // Configuring the loop timeout duration is not supported. This is OK
811 // because burst does not support HAL 1.3 and hence does not support
812 // WHILE loops.
813 CpuExecutor executor = mDriver->getExecutor();
814 if (measure == MeasureTiming::YES) deviceStart = now();
815 int n = executor.run(mModel, fullRequest, mModelPoolInfos, requestPoolInfos);
816 if (measure == MeasureTiming::YES) deviceEnd = now();
817 VLOG(DRIVER) << "executor.run returned " << n;
818 V1_0::ErrorStatus executionStatus = convertToV1_0(convertResultCodeToErrorStatus(n));
819 hidl_vec<OutputShape> outputShapes = executor.getOutputShapes();
820 if (measure == MeasureTiming::YES && executionStatus == V1_0::ErrorStatus::NONE) {
821 driverEnd = now();
822 Timing timing = {
823 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
824 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
825 VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
826 return std::make_tuple(executionStatus, outputShapes, timing);
827 } else {
828 return std::make_tuple(executionStatus, outputShapes, kNoTiming);
829 }
830 }
831
832 private:
833 const Model mModel;
834 const SampleDriver* const mDriver;
835 const std::vector<RunTimePoolInfo> mModelPoolInfos;
836 std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache; // cached requestPoolInfos
837 };
838
// This is the amount of time the ExecutionBurstServer should spend polling the
// FMQ to see if it has data available before it should fall back to waiting on
// the futex.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t kDefaultPollingTimeWindow = 50;
#ifdef NN_DEBUGGABLE
    // On debuggable builds the window is tunable via a system property
    // (values below the minimum are clamped by GetIntProperty).
    constexpr int32_t kMinPollingTimeWindow = 0;
    const int32_t window =
            base::GetIntProperty("debug.nn.sample-driver-burst-polling-window",
                                 kDefaultPollingTimeWindow, kMinPollingTimeWindow);
    return std::chrono::microseconds{window};
#else
    return std::chrono::microseconds{kDefaultPollingTimeWindow};
#endif  // NN_DEBUGGABLE
}
854
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)855 Return<void> SamplePreparedModel::configureExecutionBurst(
856 const sp<V1_2::IBurstCallback>& callback,
857 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
858 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
859 configureExecutionBurst_cb cb) {
860 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
861 "SampleDriver::configureExecutionBurst");
862
863 const bool preferPowerOverLatency = (kPreference == ExecutionPreference::LOW_POWER);
864 const auto pollingTimeWindow =
865 (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
866
867 // Alternatively, the burst could be configured via:
868 // const sp<V1_2::IBurstContext> burst =
869 // ExecutionBurstServer::create(callback, requestChannel,
870 // resultChannel, this,
871 // pollingTimeWindow);
872 //
873 // However, this alternative representation does not include a memory map
874 // caching optimization, and adds overhead.
875 const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
876 std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
877 const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
878 callback, requestChannel, resultChannel, executorWithCache, pollingTimeWindow);
879
880 if (burst == nullptr) {
881 cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
882 } else {
883 cb(V1_0::ErrorStatus::NONE, burst);
884 }
885
886 return Void();
887 }
888
889 } // namespace sample_driver
890 } // namespace nn
891 } // namespace android
892