1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ExecutionBurstController"
18 
19 #include "ExecutionBurstController.h"
20 
21 #include <android-base/logging.h>
22 
23 #include <algorithm>
24 #include <cstring>
25 #include <limits>
26 #include <memory>
27 #include <string>
28 #include <tuple>
29 #include <utility>
30 #include <vector>
31 
32 #include "HalInterfaces.h"
33 #include "Tracing.h"
34 #include "Utils.h"
35 
36 namespace android::nn {
37 namespace {
38 
39 using namespace hal;
40 
41 using V1_2::FmqRequestDatum;
42 using V1_2::FmqResultDatum;
43 using V1_2::IBurstCallback;
44 using V1_2::IBurstContext;
45 using FmqRequestDescriptor = hardware::MQDescriptorSync<FmqRequestDatum>;
46 using FmqResultDescriptor = hardware::MQDescriptorSync<FmqResultDatum>;
47 
48 constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
49                               std::numeric_limits<uint64_t>::max()};
50 
51 class BurstContextDeathHandler : public hidl_death_recipient {
52    public:
53     using Callback = std::function<void()>;
54 
BurstContextDeathHandler(const Callback & onDeathCallback)55     BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
56         CHECK(onDeathCallback != nullptr);
57     }
58 
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)59     void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
60         LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
61         mOnDeathCallback();
62     }
63 
64    private:
65     const Callback mOnDeathCallback;
66 };
67 
68 }  // anonymous namespace
69 
70 // serialize a request into a packet
serialize(const V1_0::Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)71 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, MeasureTiming measure,
72                                        const std::vector<int32_t>& slots) {
73     // count how many elements need to be sent for a request
74     size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
75     for (const auto& input : request.inputs) {
76         count += input.dimensions.size();
77     }
78     for (const auto& output : request.outputs) {
79         count += output.dimensions.size();
80     }
81 
82     // create buffer to temporarily store elements
83     std::vector<FmqRequestDatum> data;
84     data.reserve(count);
85 
86     // package packetInfo
87     {
88         FmqRequestDatum datum;
89         datum.packetInformation(
90                 {/*.packetSize=*/static_cast<uint32_t>(count),
91                  /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
92                  /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
93                  /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
94         data.push_back(datum);
95     }
96 
97     // package input data
98     for (const auto& input : request.inputs) {
99         // package operand information
100         FmqRequestDatum datum;
101         datum.inputOperandInformation(
102                 {/*.hasNoValue=*/input.hasNoValue,
103                  /*.location=*/input.location,
104                  /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
105         data.push_back(datum);
106 
107         // package operand dimensions
108         for (uint32_t dimension : input.dimensions) {
109             FmqRequestDatum datum;
110             datum.inputOperandDimensionValue(dimension);
111             data.push_back(datum);
112         }
113     }
114 
115     // package output data
116     for (const auto& output : request.outputs) {
117         // package operand information
118         FmqRequestDatum datum;
119         datum.outputOperandInformation(
120                 {/*.hasNoValue=*/output.hasNoValue,
121                  /*.location=*/output.location,
122                  /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
123         data.push_back(datum);
124 
125         // package operand dimensions
126         for (uint32_t dimension : output.dimensions) {
127             FmqRequestDatum datum;
128             datum.outputOperandDimensionValue(dimension);
129             data.push_back(datum);
130         }
131     }
132 
133     // package pool identifier
134     for (int32_t slot : slots) {
135         FmqRequestDatum datum;
136         datum.poolIdentifier(slot);
137         data.push_back(datum);
138     }
139 
140     // package measureTiming
141     {
142         FmqRequestDatum datum;
143         datum.measureTiming(measure);
144         data.push_back(datum);
145     }
146 
147     // return packet
148     return data;
149 }
150 
151 // deserialize a packet into the result
deserialize(const std::vector<FmqResultDatum> & data)152 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>> deserialize(
153         const std::vector<FmqResultDatum>& data) {
154     using discriminator = FmqResultDatum::hidl_discriminator;
155 
156     std::vector<OutputShape> outputShapes;
157     size_t index = 0;
158 
159     // validate packet information
160     if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
161         LOG(ERROR) << "FMQ Result packet ill-formed";
162         return std::nullopt;
163     }
164 
165     // unpackage packet information
166     const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
167     index++;
168     const uint32_t packetSize = packetInfo.packetSize;
169     const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
170     const uint32_t numberOfOperands = packetInfo.numberOfOperands;
171 
172     // verify packet size
173     if (data.size() != packetSize) {
174         LOG(ERROR) << "FMQ Result packet ill-formed";
175         return std::nullopt;
176     }
177 
178     // unpackage operands
179     for (size_t operand = 0; operand < numberOfOperands; ++operand) {
180         // validate operand information
181         if (data[index].getDiscriminator() != discriminator::operandInformation) {
182             LOG(ERROR) << "FMQ Result packet ill-formed";
183             return std::nullopt;
184         }
185 
186         // unpackage operand information
187         const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
188         index++;
189         const bool isSufficient = operandInfo.isSufficient;
190         const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
191 
192         // unpackage operand dimensions
193         std::vector<uint32_t> dimensions;
194         dimensions.reserve(numberOfDimensions);
195         for (size_t i = 0; i < numberOfDimensions; ++i) {
196             // validate dimension
197             if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
198                 LOG(ERROR) << "FMQ Result packet ill-formed";
199                 return std::nullopt;
200             }
201 
202             // unpackage dimension
203             const uint32_t dimension = data[index].operandDimensionValue();
204             index++;
205 
206             // store result
207             dimensions.push_back(dimension);
208         }
209 
210         // store result
211         outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
212     }
213 
214     // validate execution timing
215     if (data[index].getDiscriminator() != discriminator::executionTiming) {
216         LOG(ERROR) << "FMQ Result packet ill-formed";
217         return std::nullopt;
218     }
219 
220     // unpackage execution timing
221     const Timing timing = data[index].executionTiming();
222     index++;
223 
224     // validate packet information
225     if (index != packetSize) {
226         LOG(ERROR) << "FMQ Result packet ill-formed";
227         return std::nullopt;
228     }
229 
230     // return result
231     return std::make_tuple(errorStatus, std::move(outputShapes), timing);
232 }
233 
legacyConvertResultCodeToErrorStatus(int resultCode)234 V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) {
235     return convertToV1_0(convertResultCodeToErrorStatus(resultCode));
236 }
237 
238 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
create(size_t channelLength,std::chrono::microseconds pollingTimeWindow)239 ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
240     std::unique_ptr<FmqResultChannel> fmqResultChannel =
241             std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/true);
242     if (!fmqResultChannel->isValid()) {
243         LOG(ERROR) << "Unable to create ResultChannelReceiver";
244         return {nullptr, nullptr};
245     }
246 
247     const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
248     return std::make_pair(
249             std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), pollingTimeWindow),
250             descriptor);
251 }
252 
ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,std::chrono::microseconds pollingTimeWindow)253 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
254                                              std::chrono::microseconds pollingTimeWindow)
255     : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {}
256 
257 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<OutputShape>, Timing>>
getBlocking()258 ResultChannelReceiver::getBlocking() {
259     const auto packet = getPacketBlocking();
260     if (!packet) {
261         return std::nullopt;
262     }
263 
264     return deserialize(*packet);
265 }
266 
invalidate()267 void ResultChannelReceiver::invalidate() {
268     mValid = false;
269 
270     // force unblock
271     // ExecutionBurstController waits on a result packet after sending a
272     // request. If the driver containing ExecutionBurstServer crashes, the
273     // controller may be waiting on the futex. This force unblock wakes up any
274     // thread waiting on the futex.
275     // TODO: look for a different/better way to signal/notify the futex to
276     // wake up any thread waiting on it
277     FmqResultDatum datum;
278     datum.packetInformation({/*.packetSize=*/0, /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE,
279                              /*.numberOfOperands=*/0});
280     mFmqResultChannel->writeBlocking(&datum, 1);
281 }
282 
getPacketBlocking()283 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
284     if (!mValid) {
285         return std::nullopt;
286     }
287 
288     // First spend time polling if results are available in FMQ instead of
289     // waiting on the futex. Polling is more responsive (yielding lower
290     // latencies), but can take up more power, so only poll for a limited period
291     // of time.
292 
293     auto& getCurrentTime = std::chrono::high_resolution_clock::now;
294     const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
295 
296     while (getCurrentTime() < timeToStopPolling) {
297         // if class is being torn down, immediately return
298         if (!mValid.load(std::memory_order_relaxed)) {
299             return std::nullopt;
300         }
301 
302         // Check if data is available. If it is, immediately retrieve it and
303         // return.
304         const size_t available = mFmqResultChannel->availableToRead();
305         if (available > 0) {
306             std::vector<FmqResultDatum> packet(available);
307             const bool success = mFmqResultChannel->read(packet.data(), available);
308             if (!success) {
309                 LOG(ERROR) << "Error receiving packet";
310                 return std::nullopt;
311             }
312             return std::make_optional(std::move(packet));
313         }
314     }
315 
316     // If we get to this point, we either stopped polling because it was taking
317     // too long or polling was not allowed. Instead, perform a blocking call
318     // which uses a futex to save power.
319 
320     // wait for result packet and read first element of result packet
321     FmqResultDatum datum;
322     bool success = mFmqResultChannel->readBlocking(&datum, 1);
323 
324     // retrieve remaining elements
325     // NOTE: all of the data is already available at this point, so there's no
326     // need to do a blocking wait to wait for more data. This is known because
327     // in FMQ, all writes are published (made available) atomically. Currently,
328     // the producer always publishes the entire packet in one function call, so
329     // if the first element of the packet is available, the remaining elements
330     // are also available.
331     const size_t count = mFmqResultChannel->availableToRead();
332     std::vector<FmqResultDatum> packet(count + 1);
333     std::memcpy(&packet.front(), &datum, sizeof(datum));
334     success &= mFmqResultChannel->read(packet.data() + 1, count);
335 
336     if (!mValid) {
337         return std::nullopt;
338     }
339 
340     // ensure packet was successfully received
341     if (!success) {
342         LOG(ERROR) << "Error receiving packet";
343         return std::nullopt;
344     }
345 
346     return std::make_optional(std::move(packet));
347 }
348 
349 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
create(size_t channelLength)350 RequestChannelSender::create(size_t channelLength) {
351     std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
352             std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/true);
353     if (!fmqRequestChannel->isValid()) {
354         LOG(ERROR) << "Unable to create RequestChannelSender";
355         return {nullptr, nullptr};
356     }
357 
358     const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
359     return std::make_pair(std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel)),
360                           descriptor);
361 }
362 
RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)363 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
364     : mFmqRequestChannel(std::move(fmqRequestChannel)) {}
365 
send(const V1_0::Request & request,MeasureTiming measure,const std::vector<int32_t> & slots)366 bool RequestChannelSender::send(const V1_0::Request& request, MeasureTiming measure,
367                                 const std::vector<int32_t>& slots) {
368     const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
369     return sendPacket(serialized);
370 }
371 
sendPacket(const std::vector<FmqRequestDatum> & packet)372 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
373     if (!mValid) {
374         return false;
375     }
376 
377     if (packet.size() > mFmqRequestChannel->availableToWrite()) {
378         LOG(ERROR)
379                 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
380         return false;
381     }
382 
383     // Always send the packet with "blocking" because this signals the futex and
384     // unblocks the consumer if it is waiting on the futex.
385     return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
386 }
387 
invalidate()388 void RequestChannelSender::invalidate() {
389     mValid = false;
390 }
391 
getMemories(const hidl_vec<int32_t> & slots,getMemories_cb cb)392 Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
393         const hidl_vec<int32_t>& slots, getMemories_cb cb) {
394     std::lock_guard<std::mutex> guard(mMutex);
395 
396     // get all memories
397     hidl_vec<hidl_memory> memories(slots.size());
398     std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
399         return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
400     });
401 
402     // ensure all memories are valid
403     if (!std::all_of(memories.begin(), memories.end(),
404                      [](const hidl_memory& memory) { return memory.valid(); })) {
405         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
406         return Void();
407     }
408 
409     // return successful
410     cb(V1_0::ErrorStatus::NONE, std::move(memories));
411     return Void();
412 }
413 
getSlots(const hidl_vec<hidl_memory> & memories,const std::vector<intptr_t> & keys)414 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
415         const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) {
416     std::lock_guard<std::mutex> guard(mMutex);
417 
418     // retrieve (or bind) all slots corresponding to memories
419     std::vector<int32_t> slots;
420     slots.reserve(memories.size());
421     for (size_t i = 0; i < memories.size(); ++i) {
422         slots.push_back(getSlotLocked(memories[i], keys[i]));
423     }
424     return slots;
425 }
426 
freeMemory(intptr_t key)427 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
428         intptr_t key) {
429     std::lock_guard<std::mutex> guard(mMutex);
430 
431     auto iter = mMemoryIdToSlot.find(key);
432     if (iter == mMemoryIdToSlot.end()) {
433         return {false, 0};
434     }
435     const int32_t slot = iter->second;
436     mMemoryIdToSlot.erase(key);
437     mMemoryCache[slot] = {};
438     mFreeSlots.push(slot);
439     return {true, slot};
440 }
441 
getSlotLocked(const hidl_memory & memory,intptr_t key)442 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
443                                                                         intptr_t key) {
444     auto iter = mMemoryIdToSlot.find(key);
445     if (iter == mMemoryIdToSlot.end()) {
446         const int32_t slot = allocateSlotLocked();
447         mMemoryIdToSlot[key] = slot;
448         mMemoryCache[slot] = memory;
449         return slot;
450     } else {
451         const int32_t slot = iter->second;
452         return slot;
453     }
454 }
455 
allocateSlotLocked()456 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
457     constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
458 
459     // if there is a free slot, use it
460     if (mFreeSlots.size() > 0) {
461         const int32_t slot = mFreeSlots.top();
462         mFreeSlots.pop();
463         return slot;
464     }
465 
466     // otherwise use a slot for the first time
467     CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
468     const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
469     mMemoryCache.emplace_back();
470 
471     return slot;
472 }
473 
create(const sp<V1_2::IPreparedModel> & preparedModel,std::chrono::microseconds pollingTimeWindow)474 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
475         const sp<V1_2::IPreparedModel>& preparedModel,
476         std::chrono::microseconds pollingTimeWindow) {
477     // check inputs
478     if (preparedModel == nullptr) {
479         LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
480         return nullptr;
481     }
482 
483     // create callback object
484     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
485 
486     // create FMQ objects
487     auto [requestChannelSenderTemp, requestChannelDescriptor] =
488             RequestChannelSender::create(kExecutionBurstChannelLength);
489     auto [resultChannelReceiverTemp, resultChannelDescriptor] =
490             ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
491     std::shared_ptr<RequestChannelSender> requestChannelSender =
492             std::move(requestChannelSenderTemp);
493     std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
494             std::move(resultChannelReceiverTemp);
495 
496     // check FMQ objects
497     if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
498         !resultChannelDescriptor) {
499         LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
500         return nullptr;
501     }
502 
503     // configure burst
504     V1_0::ErrorStatus errorStatus;
505     sp<IBurstContext> burstContext;
506     const Return<void> ret = preparedModel->configureExecutionBurst(
507             callback, *requestChannelDescriptor, *resultChannelDescriptor,
508             [&errorStatus, &burstContext](V1_0::ErrorStatus status,
509                                           const sp<IBurstContext>& context) {
510                 errorStatus = status;
511                 burstContext = context;
512             });
513 
514     // check burst
515     if (!ret.isOk()) {
516         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
517                    << ret.description();
518         return nullptr;
519     }
520     if (errorStatus != V1_0::ErrorStatus::NONE) {
521         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
522                    << toString(errorStatus);
523         return nullptr;
524     }
525     if (burstContext == nullptr) {
526         LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
527         return nullptr;
528     }
529 
530     // create death handler object
531     BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
532                                                           resultChannelReceiver] {
533         requestChannelSender->invalidate();
534         resultChannelReceiver->invalidate();
535     };
536     const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
537 
538     // linkToDeath registers a callback that will be invoked on service death to
539     // proactively handle service crashes. If the linkToDeath call fails,
540     // asynchronous calls are susceptible to hangs if the service crashes before
541     // providing the response.
542     const Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
543     if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
544         LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
545                       "for the IBurstContext object.";
546         return nullptr;
547     }
548 
549     // make and return controller
550     return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
551                                                       burstContext, callback, deathHandler);
552 }
553 
ExecutionBurstController(const std::shared_ptr<RequestChannelSender> & requestChannelSender,const std::shared_ptr<ResultChannelReceiver> & resultChannelReceiver,const sp<IBurstContext> & burstContext,const sp<ExecutionBurstCallback> & callback,const sp<hidl_death_recipient> & deathHandler)554 ExecutionBurstController::ExecutionBurstController(
555         const std::shared_ptr<RequestChannelSender>& requestChannelSender,
556         const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
557         const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
558         const sp<hidl_death_recipient>& deathHandler)
559     : mRequestChannelSender(requestChannelSender),
560       mResultChannelReceiver(resultChannelReceiver),
561       mBurstContext(burstContext),
562       mMemoryCache(callback),
563       mDeathHandler(deathHandler) {}
564 
~ExecutionBurstController()565 ExecutionBurstController::~ExecutionBurstController() {
566     // It is safe to ignore any errors resulting from this unlinkToDeath call
567     // because the ExecutionBurstController object is already being destroyed
568     // and its underlying IBurstContext object is no longer being used by the NN
569     // runtime.
570     if (mDeathHandler) {
571         mBurstContext->unlinkToDeath(mDeathHandler).isOk();
572     }
573 }
574 
getExecutionResult(V1_0::ErrorStatus status,std::vector<OutputShape> outputShapes,Timing timing,bool fallback)575 static std::tuple<int, std::vector<OutputShape>, Timing, bool> getExecutionResult(
576         V1_0::ErrorStatus status, std::vector<OutputShape> outputShapes, Timing timing,
577         bool fallback) {
578     auto [n, checkedOutputShapes, checkedTiming] =
579             getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
580     return {n, std::move(checkedOutputShapes), checkedTiming, fallback};
581 }
582 
compute(const V1_0::Request & request,MeasureTiming measure,const std::vector<intptr_t> & memoryIds)583 std::tuple<int, std::vector<OutputShape>, Timing, bool> ExecutionBurstController::compute(
584         const V1_0::Request& request, MeasureTiming measure,
585         const std::vector<intptr_t>& memoryIds) {
586     // This is the first point when we know an execution is occurring, so begin
587     // to collect systraces. Note that the first point we can begin collecting
588     // systraces in ExecutionBurstServer is when the RequestChannelReceiver
589     // realizes there is data in the FMQ, so ExecutionBurstServer collects
590     // systraces at different points in the code.
591     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
592 
593     std::lock_guard<std::mutex> guard(mMutex);
594 
595     // send request packet
596     const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
597     const bool success = mRequestChannelSender->send(request, measure, slots);
598     if (!success) {
599         LOG(ERROR) << "Error sending FMQ packet";
600         // only use fallback execution path if the packet could not be sent
601         return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming,
602                                   /*fallback=*/true);
603     }
604 
605     // get result packet
606     const auto result = mResultChannelReceiver->getBlocking();
607     if (!result) {
608         LOG(ERROR) << "Error retrieving FMQ packet";
609         // only use fallback execution path if the packet could not be sent
610         return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming,
611                                   /*fallback=*/false);
612     }
613 
614     // unpack results and return (only use fallback execution path if the
615     // packet could not be sent)
616     auto [status, outputShapes, timing] = std::move(*result);
617     return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
618 }
619 
freeMemory(intptr_t key)620 void ExecutionBurstController::freeMemory(intptr_t key) {
621     std::lock_guard<std::mutex> guard(mMutex);
622 
623     bool valid;
624     int32_t slot;
625     std::tie(valid, slot) = mMemoryCache->freeMemory(key);
626     if (valid) {
627         mBurstContext->freeMemory(slot).isOk();
628     }
629 }
630 
631 }  // namespace android::nn
632