1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ExecutionBurstController"
18
19 #include "ExecutionBurstController.h"
20
21 #include <android-base/logging.h>
22
23 #include <algorithm>
24 #include <cstring>
25 #include <limits>
26 #include <memory>
27 #include <string>
28 #include <tuple>
29 #include <utility>
30 #include <vector>
31
32 #include "HalInterfaces.h"
33 #include "Tracing.h"
34 #include "Utils.h"
35
36 namespace android::nn {
37 namespace {
38
39 using namespace hal;
40
41 using V1_2::FmqRequestDatum;
42 using V1_2::FmqResultDatum;
43 using V1_2::IBurstCallback;
44 using V1_2::IBurstContext;
45 using FmqRequestDescriptor = hardware::MQDescriptorSync<FmqRequestDatum>;
46 using FmqResultDescriptor = hardware::MQDescriptorSync<FmqResultDatum>;
47
48 constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
49 std::numeric_limits<uint64_t>::max()};
50
51 class BurstContextDeathHandler : public hidl_death_recipient {
52 public:
53 using Callback = std::function<void()>;
54
BurstContextDeathHandler(const Callback & onDeathCallback)55 BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
56 CHECK(onDeathCallback != nullptr);
57 }
58
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)59 void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
60 LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
61 mOnDeathCallback();
62 }
63
64 private:
65 const Callback mOnDeathCallback;
66 };
67
68 } // anonymous namespace
69
70 // serialize a request into a packet
serialize(const V1_0::Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)71 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, MeasureTiming measure,
72 const std::vector<int32_t>& slots) {
73 // count how many elements need to be sent for a request
74 size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
75 for (const auto& input : request.inputs) {
76 count += input.dimensions.size();
77 }
78 for (const auto& output : request.outputs) {
79 count += output.dimensions.size();
80 }
81
82 // create buffer to temporarily store elements
83 std::vector<FmqRequestDatum> data;
84 data.reserve(count);
85
86 // package packetInfo
87 {
88 FmqRequestDatum datum;
89 datum.packetInformation(
90 {/*.packetSize=*/static_cast<uint32_t>(count),
91 /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
92 /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
93 /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
94 data.push_back(datum);
95 }
96
97 // package input data
98 for (const auto& input : request.inputs) {
99 // package operand information
100 FmqRequestDatum datum;
101 datum.inputOperandInformation(
102 {/*.hasNoValue=*/input.hasNoValue,
103 /*.location=*/input.location,
104 /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
105 data.push_back(datum);
106
107 // package operand dimensions
108 for (uint32_t dimension : input.dimensions) {
109 FmqRequestDatum datum;
110 datum.inputOperandDimensionValue(dimension);
111 data.push_back(datum);
112 }
113 }
114
115 // package output data
116 for (const auto& output : request.outputs) {
117 // package operand information
118 FmqRequestDatum datum;
119 datum.outputOperandInformation(
120 {/*.hasNoValue=*/output.hasNoValue,
121 /*.location=*/output.location,
122 /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
123 data.push_back(datum);
124
125 // package operand dimensions
126 for (uint32_t dimension : output.dimensions) {
127 FmqRequestDatum datum;
128 datum.outputOperandDimensionValue(dimension);
129 data.push_back(datum);
130 }
131 }
132
133 // package pool identifier
134 for (int32_t slot : slots) {
135 FmqRequestDatum datum;
136 datum.poolIdentifier(slot);
137 data.push_back(datum);
138 }
139
140 // package measureTiming
141 {
142 FmqRequestDatum datum;
143 datum.measureTiming(measure);
144 data.push_back(datum);
145 }
146
147 // return packet
148 return data;
149 }
150
151 // deserialize a packet into the result
deserialize(const std::vector<FmqResultDatum> & data)152 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
153 const std::vector<FmqResultDatum>& data) {
154 using discriminator = FmqResultDatum::hidl_discriminator;
155
156 std::vector<OutputShape> outputShapes;
157 size_t index = 0;
158
159 // validate packet information
160 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
161 LOG(ERROR) << "FMQ Result packet ill-formed";
162 return std::nullopt;
163 }
164
165 // unpackage packet information
166 const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
167 index++;
168 const uint32_t packetSize = packetInfo.packetSize;
169 const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
170 const uint32_t numberOfOperands = packetInfo.numberOfOperands;
171
172 // verify packet size
173 if (data.size() != packetSize) {
174 LOG(ERROR) << "FMQ Result packet ill-formed";
175 return std::nullopt;
176 }
177
178 // unpackage operands
179 for (size_t operand = 0; operand < numberOfOperands; ++operand) {
180 // validate operand information
181 if (data[index].getDiscriminator() != discriminator::operandInformation) {
182 LOG(ERROR) << "FMQ Result packet ill-formed";
183 return std::nullopt;
184 }
185
186 // unpackage operand information
187 const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
188 index++;
189 const bool isSufficient = operandInfo.isSufficient;
190 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
191
192 // unpackage operand dimensions
193 std::vector<uint32_t> dimensions;
194 dimensions.reserve(numberOfDimensions);
195 for (size_t i = 0; i < numberOfDimensions; ++i) {
196 // validate dimension
197 if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
198 LOG(ERROR) << "FMQ Result packet ill-formed";
199 return std::nullopt;
200 }
201
202 // unpackage dimension
203 const uint32_t dimension = data[index].operandDimensionValue();
204 index++;
205
206 // store result
207 dimensions.push_back(dimension);
208 }
209
210 // store result
211 outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
212 }
213
214 // validate execution timing
215 if (data[index].getDiscriminator() != discriminator::executionTiming) {
216 LOG(ERROR) << "FMQ Result packet ill-formed";
217 return std::nullopt;
218 }
219
220 // unpackage execution timing
221 const Timing timing = data[index].executionTiming();
222 index++;
223
224 // validate packet information
225 if (index != packetSize) {
226 LOG(ERROR) << "FMQ Result packet ill-formed";
227 return std::nullopt;
228 }
229
230 // return result
231 return std::make_tuple(errorStatus, std::move(outputShapes), timing);
232 }
233
legacyConvertResultCodeToErrorStatus(int resultCode)234 V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) {
235 return convertToV1_0(convertResultCodeToErrorStatus(resultCode));
236 }
237
238 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
create(size_t channelLength,std::chrono::microseconds pollingTimeWindow)239 ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
240 std::unique_ptr<FmqResultChannel> fmqResultChannel =
241 std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/true);
242 if (!fmqResultChannel->isValid()) {
243 LOG(ERROR) << "Unable to create ResultChannelReceiver";
244 return {nullptr, nullptr};
245 }
246
247 const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
248 return std::make_pair(
249 std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), pollingTimeWindow),
250 descriptor);
251 }
252
ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,std::chrono::microseconds pollingTimeWindow)253 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
254 std::chrono::microseconds pollingTimeWindow)
255 : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {}
256
257 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>>
getBlocking()258 ResultChannelReceiver::getBlocking() {
259 const auto packet = getPacketBlocking();
260 if (!packet) {
261 return std::nullopt;
262 }
263
264 return deserialize(*packet);
265 }
266
invalidate()267 void ResultChannelReceiver::invalidate() {
268 mValid = false;
269
270 // force unblock
271 // ExecutionBurstController waits on a result packet after sending a
272 // request. If the driver containing ExecutionBurstServer crashes, the
273 // controller may be waiting on the futex. This force unblock wakes up any
274 // thread waiting on the futex.
275 // TODO: look for a different/better way to signal/notify the futex to
276 // wake up any thread waiting on it
277 FmqResultDatum datum;
278 datum.packetInformation({/*.packetSize=*/0, /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE,
279 /*.numberOfOperands=*/0});
280 mFmqResultChannel->writeBlocking(&datum, 1);
281 }
282
getPacketBlocking()283 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
284 if (!mValid) {
285 return std::nullopt;
286 }
287
288 // First spend time polling if results are available in FMQ instead of
289 // waiting on the futex. Polling is more responsive (yielding lower
290 // latencies), but can take up more power, so only poll for a limited period
291 // of time.
292
293 auto& getCurrentTime = std::chrono::high_resolution_clock::now;
294 const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
295
296 while (getCurrentTime() < timeToStopPolling) {
297 // if class is being torn down, immediately return
298 if (!mValid.load(std::memory_order_relaxed)) {
299 return std::nullopt;
300 }
301
302 // Check if data is available. If it is, immediately retrieve it and
303 // return.
304 const size_t available = mFmqResultChannel->availableToRead();
305 if (available > 0) {
306 std::vector<FmqResultDatum> packet(available);
307 const bool success = mFmqResultChannel->read(packet.data(), available);
308 if (!success) {
309 LOG(ERROR) << "Error receiving packet";
310 return std::nullopt;
311 }
312 return std::make_optional(std::move(packet));
313 }
314 }
315
316 // If we get to this point, we either stopped polling because it was taking
317 // too long or polling was not allowed. Instead, perform a blocking call
318 // which uses a futex to save power.
319
320 // wait for result packet and read first element of result packet
321 FmqResultDatum datum;
322 bool success = mFmqResultChannel->readBlocking(&datum, 1);
323
324 // retrieve remaining elements
325 // NOTE: all of the data is already available at this point, so there's no
326 // need to do a blocking wait to wait for more data. This is known because
327 // in FMQ, all writes are published (made available) atomically. Currently,
328 // the producer always publishes the entire packet in one function call, so
329 // if the first element of the packet is available, the remaining elements
330 // are also available.
331 const size_t count = mFmqResultChannel->availableToRead();
332 std::vector<FmqResultDatum> packet(count + 1);
333 std::memcpy(&packet.front(), &datum, sizeof(datum));
334 success &= mFmqResultChannel->read(packet.data() + 1, count);
335
336 if (!mValid) {
337 return std::nullopt;
338 }
339
340 // ensure packet was successfully received
341 if (!success) {
342 LOG(ERROR) << "Error receiving packet";
343 return std::nullopt;
344 }
345
346 return std::make_optional(std::move(packet));
347 }
348
349 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
create(size_t channelLength)350 RequestChannelSender::create(size_t channelLength) {
351 std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
352 std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/true);
353 if (!fmqRequestChannel->isValid()) {
354 LOG(ERROR) << "Unable to create RequestChannelSender";
355 return {nullptr, nullptr};
356 }
357
358 const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
359 return std::make_pair(std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel)),
360 descriptor);
361 }
362
RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)363 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
364 : mFmqRequestChannel(std::move(fmqRequestChannel)) {}
365
send(const V1_0::Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)366 bool RequestChannelSender::send(const V1_0::Request& request, MeasureTiming measure,
367 const std::vector<int32_t>& slots) {
368 const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
369 return sendPacket(serialized);
370 }
371
sendPacket(const std::vector<FmqRequestDatum> & packet)372 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
373 if (!mValid) {
374 return false;
375 }
376
377 if (packet.size() > mFmqRequestChannel->availableToWrite()) {
378 LOG(ERROR)
379 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
380 return false;
381 }
382
383 // Always send the packet with "blocking" because this signals the futex and
384 // unblocks the consumer if it is waiting on the futex.
385 return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
386 }
387
invalidate()388 void RequestChannelSender::invalidate() {
389 mValid = false;
390 }
391
getMemories(const hidl_vec<int32_t> & slots,getMemories_cb cb)392 Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
393 const hidl_vec<int32_t>& slots, getMemories_cb cb) {
394 std::lock_guard<std::mutex> guard(mMutex);
395
396 // get all memories
397 hidl_vec<hidl_memory> memories(slots.size());
398 std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
399 return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
400 });
401
402 // ensure all memories are valid
403 if (!std::all_of(memories.begin(), memories.end(),
404 [](const hidl_memory& memory) { return memory.valid(); })) {
405 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
406 return Void();
407 }
408
409 // return successful
410 cb(V1_0::ErrorStatus::NONE, std::move(memories));
411 return Void();
412 }
413
getSlots(const hidl_vec<hidl_memory> & memories,const std::vector<intptr_t> & keys)414 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
415 const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) {
416 std::lock_guard<std::mutex> guard(mMutex);
417
418 // retrieve (or bind) all slots corresponding to memories
419 std::vector<int32_t> slots;
420 slots.reserve(memories.size());
421 for (size_t i = 0; i < memories.size(); ++i) {
422 slots.push_back(getSlotLocked(memories[i], keys[i]));
423 }
424 return slots;
425 }
426
freeMemory(intptr_t key)427 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
428 intptr_t key) {
429 std::lock_guard<std::mutex> guard(mMutex);
430
431 auto iter = mMemoryIdToSlot.find(key);
432 if (iter == mMemoryIdToSlot.end()) {
433 return {false, 0};
434 }
435 const int32_t slot = iter->second;
436 mMemoryIdToSlot.erase(key);
437 mMemoryCache[slot] = {};
438 mFreeSlots.push(slot);
439 return {true, slot};
440 }
441
getSlotLocked(const hidl_memory & memory,intptr_t key)442 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
443 intptr_t key) {
444 auto iter = mMemoryIdToSlot.find(key);
445 if (iter == mMemoryIdToSlot.end()) {
446 const int32_t slot = allocateSlotLocked();
447 mMemoryIdToSlot[key] = slot;
448 mMemoryCache[slot] = memory;
449 return slot;
450 } else {
451 const int32_t slot = iter->second;
452 return slot;
453 }
454 }
455
allocateSlotLocked()456 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
457 constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
458
459 // if there is a free slot, use it
460 if (mFreeSlots.size() > 0) {
461 const int32_t slot = mFreeSlots.top();
462 mFreeSlots.pop();
463 return slot;
464 }
465
466 // otherwise use a slot for the first time
467 CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
468 const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
469 mMemoryCache.emplace_back();
470
471 return slot;
472 }
473
create(const sp<V1_2::IPreparedModel> & preparedModel,std::chrono::microseconds pollingTimeWindow)474 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
475 const sp<V1_2::IPreparedModel>& preparedModel,
476 std::chrono::microseconds pollingTimeWindow) {
477 // check inputs
478 if (preparedModel == nullptr) {
479 LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
480 return nullptr;
481 }
482
483 // create callback object
484 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
485
486 // create FMQ objects
487 auto [requestChannelSenderTemp, requestChannelDescriptor] =
488 RequestChannelSender::create(kExecutionBurstChannelLength);
489 auto [resultChannelReceiverTemp, resultChannelDescriptor] =
490 ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
491 std::shared_ptr<RequestChannelSender> requestChannelSender =
492 std::move(requestChannelSenderTemp);
493 std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
494 std::move(resultChannelReceiverTemp);
495
496 // check FMQ objects
497 if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
498 !resultChannelDescriptor) {
499 LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
500 return nullptr;
501 }
502
503 // configure burst
504 V1_0::ErrorStatus errorStatus;
505 sp<IBurstContext> burstContext;
506 const Return<void> ret = preparedModel->configureExecutionBurst(
507 callback, *requestChannelDescriptor, *resultChannelDescriptor,
508 [&errorStatus, &burstContext](V1_0::ErrorStatus status,
509 const sp<IBurstContext>& context) {
510 errorStatus = status;
511 burstContext = context;
512 });
513
514 // check burst
515 if (!ret.isOk()) {
516 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
517 << ret.description();
518 return nullptr;
519 }
520 if (errorStatus != V1_0::ErrorStatus::NONE) {
521 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
522 << toString(errorStatus);
523 return nullptr;
524 }
525 if (burstContext == nullptr) {
526 LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
527 return nullptr;
528 }
529
530 // create death handler object
531 BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
532 resultChannelReceiver] {
533 requestChannelSender->invalidate();
534 resultChannelReceiver->invalidate();
535 };
536 const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
537
538 // linkToDeath registers a callback that will be invoked on service death to
539 // proactively handle service crashes. If the linkToDeath call fails,
540 // asynchronous calls are susceptible to hangs if the service crashes before
541 // providing the response.
542 const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
543 if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
544 LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
545 "for the IBurstContext object.";
546 return nullptr;
547 }
548
549 // make and return controller
550 return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
551 burstContext, callback, deathHandler);
552 }
553
ExecutionBurstController(const std::shared_ptr<RequestChannelSender> & requestChannelSender,const std::shared_ptr<ResultChannelReceiver> & resultChannelReceiver,const sp<IBurstContext> & burstContext,const sp<ExecutionBurstCallback> & callback,const sp<hidl_death_recipient> & deathHandler)554 ExecutionBurstController::ExecutionBurstController(
555 const std::shared_ptr<RequestChannelSender>& requestChannelSender,
556 const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
557 const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
558 const sp<hidl_death_recipient>& deathHandler)
559 : mRequestChannelSender(requestChannelSender),
560 mResultChannelReceiver(resultChannelReceiver),
561 mBurstContext(burstContext),
562 mMemoryCache(callback),
563 mDeathHandler(deathHandler) {}
564
~ExecutionBurstController()565 ExecutionBurstController::~ExecutionBurstController() {
566 // It is safe to ignore any errors resulting from this unlinkToDeath call
567 // because the ExecutionBurstController object is already being destroyed
568 // and its underlying IBurstContext object is no longer being used by the NN
569 // runtime.
570 if (mDeathHandler) {
571 mBurstContext->unlinkToDeath(mDeathHandler).isOk();
572 }
573 }
574
getExecutionResult(V1_0::ErrorStatus status,std::vector<OutputShape> outputShapes,Timing timing,bool fallback)575 static std::tuple<int, std::vector<OutputShape>, Timing, bool> getExecutionResult(
576 V1_0::ErrorStatus status, std::vector<OutputShape> outputShapes, Timing timing,
577 bool fallback) {
578 auto [n, checkedOutputShapes, checkedTiming] =
579 getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
580 return {n, std::move(checkedOutputShapes), checkedTiming, fallback};
581 }
582
compute(const V1_0::Request & request,MeasureTiming measure,const std::vector<intptr_t> & memoryIds)583 std::tuple<int, std::vector<OutputShape>, Timing, bool> ExecutionBurstController::compute(
584 const V1_0::Request& request, MeasureTiming measure,
585 const std::vector<intptr_t>& memoryIds) {
586 // This is the first point when we know an execution is occurring, so begin
587 // to collect systraces. Note that the first point we can begin collecting
588 // systraces in ExecutionBurstServer is when the RequestChannelReceiver
589 // realizes there is data in the FMQ, so ExecutionBurstServer collects
590 // systraces at different points in the code.
591 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
592
593 std::lock_guard<std::mutex> guard(mMutex);
594
595 // send request packet
596 const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
597 const bool success = mRequestChannelSender->send(request, measure, slots);
598 if (!success) {
599 LOG(ERROR) << "Error sending FMQ packet";
600 // only use fallback execution path if the packet could not be sent
601 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming,
602 /*fallback=*/true);
603 }
604
605 // get result packet
606 const auto result = mResultChannelReceiver->getBlocking();
607 if (!result) {
608 LOG(ERROR) << "Error retrieving FMQ packet";
609 // only use fallback execution path if the packet could not be sent
610 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming,
611 /*fallback=*/false);
612 }
613
614 // unpack results and return (only use fallback execution path if the
615 // packet could not be sent)
616 auto [status, outputShapes, timing] = std::move(*result);
617 return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
618 }
619
freeMemory(intptr_t key)620 void ExecutionBurstController::freeMemory(intptr_t key) {
621 std::lock_guard<std::mutex> guard(mMutex);
622
623 bool valid;
624 int32_t slot;
625 std::tie(valid, slot) = mMemoryCache->freeMemory(key);
626 if (valid) {
627 mBurstContext->freeMemory(slot).isOk();
628 }
629 }
630
631 } // namespace android::nn
632