1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Classes used to plan how to execute a model across multiple devices.
18 
19 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_EXECUTION_PLAN_H
20 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_EXECUTION_PLAN_H
21 
22 #include <android-base/logging.h>
23 #include <openssl/sha.h>
24 
25 #include <chrono>
26 #include <map>
27 #include <memory>
28 #include <ostream>
29 #include <set>
30 #include <string>
31 #include <unordered_map>
32 #include <utility>
33 #include <variant>
34 #include <vector>
35 
36 #include "HalInterfaces.h"
37 #include "Memory.h"
38 #include "ModelArgumentInfo.h"
39 #include "ModelBuilder.h"
40 #include "NeuralNetworks.h"
41 #include "TokenHasher.h"
42 #include "Utils.h"
43 
44 namespace android {
45 namespace nn {
46 
47 class BurstBuilder;
48 class CompilationBuilder;
49 class Device;
50 class ExecutionBuilder;
51 class ExecutionBurstController;
52 class ExecutionPlan;
53 class Memory;
54 class PreparedModel;
55 class StepExecutor;
56 
57 struct ConstantReferenceLocation;
58 
// NNAPI Control Flow allows referring to an NNAPI model inside another NNAPI
// model using OperandType::SUBGRAPH. For example, an IF operation within a
// model may refer to two other models corresponding to then and else branches.
//
// The partitioning process transforms this nested representation into a list
// of LogicalSteps.
//
// The following terms are used:
// - The main model is the top-level model being compiled (not referenced by any
//   OperandType::SUBGRAPH operand within the compilation).
// - A referenced model is a non-top-level model being compiled (referenced by
//   at least one OperandType::SUBGRAPH operand within the set of models being
//   compiled).
// - A source model is either the main model or a referenced model.
// - A step model is a model excerpted from a source model during the
//   partitioning process.
// - A partition is a LogicalStep representing at least one operation of a
//   source model. In particular, ExecutionStep represents a step model, IfStep
//   represents an IF operation, WhileStep represents a WHILE operation.
//   A GotoStep is not a partition.
// - A partition boundary operand is a source model operand that is an input or
//   output of a partition. For ExecutionStep, the inputs and outputs of the
//   step model are boundary operands; for IfStep and WhileStep, the inputs and
//   outputs of the corresponding operation are boundary operands.
//
// Referenced models can be sources of partition boundary operands. For example,
// this happens when a referenced model is partitioned into one or more
// LogicalSteps.
//
// Identifies an operand of a source model:
// (model index, operand index within model)
using SourceOperandIndex = std::pair<uint32_t, uint32_t>;
90 
91 // A collection of source models.
92 class SourceModels {
93    public:
addModel(const ModelBuilder * model)94     uint32_t addModel(const ModelBuilder* model) {
95         uint32_t modelIndex = mModels.size();
96         mModels.push_back(model);
97         return modelIndex;
98     }
99 
getModel(uint32_t index)100     const ModelBuilder* getModel(uint32_t index) const { return mModels[index]; }
101 
size()102     uint32_t size() const { return mModels.size(); }
103 
104    private:
105     std::vector<const ModelBuilder*> mModels;
106 };
107 
108 // An excerpt of a source model to be run by a specific device.
109 class ExecutionStep {
110    public:
111     typedef std::vector<std::pair<uint32_t, uint32_t>> RemapVectorType;
112     typedef std::set<std::pair<uint32_t, uint32_t>> StepModelOutputSetType;
113 
114     enum OperandKind { INPUT, OUTPUT };
115 
116     ExecutionStep(ExecutionPlan* plan, uint32_t stepIndex, uint32_t sourceModelIndex,
117                   std::shared_ptr<Device> device);
118 
119     int addOperation(int operationIndex);
120     int addOperand(uint32_t sourceOperandIndex, uint32_t* stepOperandIndex, OperandKind kind);
121 
122     // Each container entry is of the form (source model operand index, step model operand index)
getModelInputs()123     const RemapVectorType& getModelInputs() const { return mModelInputs; }
getModelOutputs()124     const RemapVectorType& getModelOutputs() const { return mModelOutputs; }
getTempsAsStepModelInputs()125     const RemapVectorType& getTempsAsStepModelInputs() const { return mTempsAsStepModelInputs; }
getTempsAsStepModelOutputs()126     const StepModelOutputSetType& getTempsAsStepModelOutputs() const {
127         return mTempsAsStepModelOutputs;
128     }
getOutputsAsStepModelInputs()129     const RemapVectorType& getOutputsAsStepModelInputs() const { return mOutputsAsStepModelInputs; }
getInputIndexStepModelToMainModel()130     const std::vector<uint32_t>& getInputIndexStepModelToMainModel() const {
131         return mInputIndexStepModelToMainModel;
132     }
getOutputIndexStepModelToMainModel()133     const std::vector<uint32_t>& getOutputIndexStepModelToMainModel() const {
134         return mOutputIndexStepModelToMainModel;
135     }
getOutputsAsStepModelInputsIndexToMainModel()136     const std::vector<uint32_t>& getOutputsAsStepModelInputsIndexToMainModel() const {
137         return mOutputsAsStepModelInputsIndexToMainModel;
138     }
139 
getSourceModelIndex()140     uint32_t getSourceModelIndex() const { return mSourceModelIndex; }
141 
142     void recordTempAsStepModelOutput(uint32_t stepOperandIndex);
143 
144     // If this step has a step model output of unknown size, sets
145     // *hasOutputOfUnknownSize to true; otherwise, leaves it
146     // unchanged.
147     int finishStepModel(const ModelBuilder* mainModel, bool* hasOutputOfUnknownSize,
148                         int32_t executionPreference, int32_t priority);
149 
getStepModel()150     const ModelBuilder* getStepModel() const { return &mStepModel; }
getDevice()151     std::shared_ptr<Device> getDevice() const { return mDevice; }
152 
153     // only available after calling finishStepModel()
getPreparedStepModel()154     std::shared_ptr<PreparedModel> getPreparedStepModel() const { return mPreparedStepModel; }
155 
156     // Map inputs and outputs from ExecutionBuilder to StepExecutor.
157     //
158     // This method only reads map entries for which the first element of
159     // SourceOperandIndex is mSourceModelIndex.
160     void mapInputsAndOutputs(
161             std::shared_ptr<StepExecutor> stepExecutor, const Memory* temporaryMemory,
162             const std::map<SourceOperandIndex, uint32_t>& sourceOperandToOffsetOfTemporary,
163             const std::map<SourceOperandIndex, uint32_t>& sourceOperandToInputIndex,
164             const std::map<SourceOperandIndex, uint32_t>& sourceOperandToOutputIndex,
165             const std::map<SourceOperandIndex, ConstantReferenceLocation>&
166                     sourceOperandToConstantReference) const;
167 
168     void dump() const;
169 
170     // For test only, get the transformed cache token.
forTest_getCacheToken()171     const uint8_t* forTest_getCacheToken() const { return mToken.getCacheToken(); }
172 
173    private:
174     void logStepModel() const;
175     const ModelBuilder* getSourceModel() const;
176 
177     // TODO: Some of the data is working state information that
178     // shouldn't be needed after we've constructed but not executed
179     // the step.
180 
181     ExecutionPlan* mPlan;
182     uint32_t mIndex;  // index of step within plan
183     uint32_t mSourceModelIndex;
184     ModelBuilder mStepModel;  // An excerpt of a source model to be run by one device.
185     std::shared_ptr<Device> mDevice;
186     std::shared_ptr<PreparedModel> mPreparedStepModel;
187 
188     // All inputs of this step model:
189     //     (source model operand index, step model operand index)
190     //
191     // Depending on whether the source operand is an input or output of the main
192     // model, the memory should be mapped using
193     // ExecutionPlan::CompoundBody::mSourceOperandToInputIndex,
194     // ExecutionPlan::Controller::mSourceOperandToOffsetOfTemporary, or
195     // ExecutionPlan::CompoundBody::mSourceOperandToOutputIndex.
196     RemapVectorType mStepModelInputs;
197     // All outputs of this step model:
198     //     (source model operand index, step model operand index)
199     //
200     // Depending on whether the source operand is an output of the main model,
201     // the memory should be mapped using
202     // ExecutionPlan::CompoundBody::mSourceOperandToOutputIndex or
203     // ExecutionPlan::Controller::mSourceOperandToOffsetOfTemporary.
204     //
205     // mOutputIndexStepModelToMainModel relies on mModelOutputs being a prefix of
206     // mStepModelOutputs.
207     RemapVectorType mStepModelOutputs;
208     // Inputs of main model that are also inputs of this step model:
209     //     (main model operand index, step model operand index)
210     RemapVectorType mModelInputs;
211     // Outputs of main model that are also outputs of this step model:
212     //     (main model operand index, step model operand index)
213     RemapVectorType mModelOutputs;
214     // Temporaries of source model that are inputs of this step model:
215     //     (source model operand index, step model operand index)
216     RemapVectorType mTempsAsStepModelInputs;
217     // Temporaries of source model that are outputs of this step model:
218     //     (source model operand index, step model operand index)
219     StepModelOutputSetType mTempsAsStepModelOutputs;
220     // Outputs of main model that are inputs of this step model:
221     //     (main model operand index, step model operand index)
222     RemapVectorType mOutputsAsStepModelInputs;
223     // Converts operand indexes from the source model to the step model.
224     std::unordered_map<uint32_t, uint32_t> mOperandMap;
225     // Converts input indexes from the step model to the main model
226     // (these are input indexes, not operand indexes).  This vector
227     // only describes inputs of the step model that are also inputs of
228     // the main model -- that is, mModelInputs but not mTempsAsStepModelInputs.
229     std::vector<uint32_t> mInputIndexStepModelToMainModel;
230     // Converts output indexes from the step model to the main model
231     // (these are output indexes, not operand indexes).  This vector
232     // only describes outputs of the step model that are also outputs of
233     // the main model -- that is, mModelOutputs but not
234     // mTempsAsStepModelOutputs.
235     std::vector<uint32_t> mOutputIndexStepModelToMainModel;
236     // Converts indexes into mOutputsAsStepModelInputs to indexes into
237     // main model outputs (these are input and output indexes, not
238     // operand indexes).  To be specific, if the main model outputs
239     // are mainModelOutputs,
240     //
241     //     mOutputsAsStepModelInputsIndexToMainModel.size() ==
242     //     mOutputsAsStepModelInputs.size()
243     //
244     // and when (0 <= i < mOutputsAsStepModelInputs.size()),
245     //
246     //     mainModelOutputs[mOutputsAsStepModelInputsIndexToMainModel[i]] ==
247     //     mOutputsAsStepModelInputs[i].first
248     std::vector<uint32_t> mOutputsAsStepModelInputsIndexToMainModel;
249 
250     // The compilation caching token.
251     TokenHasher mToken;
252 };
253 
254 // An IF operation to be run on the ExecutionPlan::next() interpreter. The
255 // branch models might run on devices. See LogicalStep.
256 //
257 // Execution plan structure:
258 // Index  Step
259 //   i    if then=(i + 1) else=(j + 1)
260 //  ...   (then model steps)
261 //   j    goto k
262 //  ...   (else model steps)
263 //   k    (steps after the IF)
264 struct IfStep {
265     // The index of this step.
266     size_t index = ~size_t(0);
267     // The index of the first step of the "then" branch.
268     size_t thenStepIndex = ~size_t(0);
269     // The index of the first step of the "else" branch.
270     size_t elseStepIndex = ~size_t(0);
271     // The boolean condition input of the IF operation. The value of this
272     // operand determines the branch of the IF operation to be executed.
273     SourceOperandIndex conditionOperandIndex = {~uint32_t(0), ~uint32_t(0)};
274     // Input operands of the IF operation to be passed to a branch model.
275     std::vector<SourceOperandIndex> outerInputOperands;
276     // Output operands of the IF operation.
277     std::vector<SourceOperandIndex> outerOutputOperands;
278     // Input operands of the "then" branch model.
279     std::vector<SourceOperandIndex> thenBranchInputOperands;
280     // Output operands of the "then" branch model.
281     std::vector<SourceOperandIndex> thenBranchOutputOperands;
282     // Input operands of the "else" branch model.
283     std::vector<SourceOperandIndex> elseBranchInputOperands;
284     // Output operands of the "else" branch model.
285     std::vector<SourceOperandIndex> elseBranchOutputOperands;
286 };
287 
288 // A WHILE operation to be run on the ExecutionPlan::next() interpreter. The
289 // condition and body models might run other devices. See LogicalStep.
290 //
291 // Execution plan structure:
292 // Index  Step
293 //   i    while cond=(i + 1) body=(j + 1) exit=(k + 1)
294 //  ...   (cond model steps)
295 //   j    goto i
296 //  ...   (body model steps)
297 //   k    goto i
298 //  ...   (steps after the WHILE)
299 //
300 //  Note that WhileStep has WhileState associated with it.
301 struct WhileStep {
302     // The index of this step.
303     size_t index = ~size_t(0);
304     // The index of the first step of the condition model.
305     size_t condStepIndex = ~size_t(0);
306     // The index of the first step of the body model.
307     size_t bodyStepIndex = ~size_t(0);
308     // The index of the first step after the loop.
309     size_t exitStepIndex = ~size_t(0);
310     // Input operands of the WHILE operation to be passed to the condition and
311     // body models.
312     std::vector<SourceOperandIndex> outerInputOperands;
313     // Output operands of the WHILE operation.
314     std::vector<SourceOperandIndex> outerOutputOperands;
315     // Input operands of the condition model.
316     std::vector<SourceOperandIndex> condInputOperands;
317     // Output operand of the condition model. The value of this operand
318     // determines whether to continue execution or exit the loop.
319     SourceOperandIndex condOutputOperand = {~uint32_t(0), ~uint32_t(0)};
320     // Input operands of the body model.
321     std::vector<SourceOperandIndex> bodyInputOperands;
322     // Output operands of the body model.
323     std::vector<SourceOperandIndex> bodyOutputOperands;
324 };
325 
// A helper step. See LogicalStep.
//
// Both indexes default to an all-ones sentinel until they are filled in.
struct GotoStep {
    // The index of this step.
    size_t index = ~size_t(0);
    // The index of the step to go to.
    size_t gotoStepIndex = ~size_t(0);
};
333 
334 // One of ExecutionStep, IfStep, WhileStep, or GotoStep.
335 //
336 // When ExecutionPlan::next() is called, it interprets logical steps until it
337 // encounters an ExecutionStep ("interpreted execution").
338 // - For an IfStep, it decides which branch to take and proceeds to the
339 //   corresponding step.
340 // - For a WhileStep, it decides whether to execute the condition or body (based
341 //   on WhileState), or exit the loop (based on the condition model output), and
342 //   proceeds to the corresponding step.
343 // - For a GotoStep, it proceeds to the indicated step unconditionally.
344 class LogicalStep {
345    public:
346     template <typename... Args>
LogicalStep(Args &&...args)347     explicit LogicalStep(Args&&... args) : mStep(std::forward<Args>(args)...) {}
348 
isExecution()349     bool isExecution() const { return std::holds_alternative<ExecutionStep>(mStep); }
isIf()350     bool isIf() const { return std::holds_alternative<IfStep>(mStep); }
isWhile()351     bool isWhile() const { return std::holds_alternative<WhileStep>(mStep); }
isGoto()352     bool isGoto() const { return std::holds_alternative<GotoStep>(mStep); }
353 
354     // Returns a non-null pointer or crashes.
executionStep()355     ExecutionStep* executionStep() { return &std::get<ExecutionStep>(mStep); }
ifStep()356     IfStep* ifStep() { return &std::get<IfStep>(mStep); }
whileStep()357     WhileStep* whileStep() { return &std::get<WhileStep>(mStep); }
gotoStep()358     GotoStep* gotoStep() { return &std::get<GotoStep>(mStep); }
359 
360     // Returns a non-null pointer or crashes.
executionStep()361     const ExecutionStep* executionStep() const { return &std::get<ExecutionStep>(mStep); }
ifStep()362     const IfStep* ifStep() const { return &std::get<IfStep>(mStep); }
whileStep()363     const WhileStep* whileStep() const { return &std::get<WhileStep>(mStep); }
gotoStep()364     const GotoStep* gotoStep() const { return &std::get<GotoStep>(mStep); }
365 
366     // May return nullptr.
tryExecutionStep()367     ExecutionStep* tryExecutionStep() { return std::get_if<ExecutionStep>(&mStep); }
tryIfStep()368     IfStep* tryIfStep() { return std::get_if<IfStep>(&mStep); }
tryWhileStep()369     WhileStep* tryWhileStep() { return std::get_if<WhileStep>(&mStep); }
tryGotoStep()370     GotoStep* tryGotoStep() { return std::get_if<GotoStep>(&mStep); }
371 
372     // May return nullptr.
tryExecutionStep()373     const ExecutionStep* tryExecutionStep() const { return std::get_if<ExecutionStep>(&mStep); }
tryIfStep()374     const IfStep* tryIfStep() const { return std::get_if<IfStep>(&mStep); }
tryWhileStep()375     const WhileStep* tryWhileStep() const { return std::get_if<WhileStep>(&mStep); }
tryGotoStep()376     const GotoStep* tryGotoStep() const { return std::get_if<GotoStep>(&mStep); }
377 
378     void dump() const;
379 
380    private:
381     std::variant<ExecutionStep, IfStep, WhileStep, GotoStep> mStep;
382 };
383 
384 std::string toString(const IfStep& step);
385 std::string toString(const WhileStep& step);
386 std::string toString(const GotoStep& step);
387 
// Describes the state of WhileStep.
//
// A default-constructed WhileState is outside the loop (iteration ==
// kOutsideLoop) with the condition scheduled to be evaluated next.
struct WhileState {
    // A pseudo iteration number indicating the loop is not being executed.
    static constexpr uint64_t kOutsideLoop = ~uint64_t(0);
    // Whether we need to evaluate the condition or body next.
    enum Stage { EVALUATE_CONDITION, EVALUATE_BODY } stage = EVALUATE_CONDITION;
    // Current iteration number. Must be set to kOutsideLoop when exiting the
    // loop.
    uint64_t iteration = kOutsideLoop;
    // Time point when the loop started executing.
    std::chrono::time_point<std::chrono::steady_clock> startTime;
};
400 
// A (pointer, length) description of constant data. The pointer is a plain
// raw pointer; this struct does nothing to manage the pointed-to bytes.
//
// Members are default-initialized to an empty location so that a
// default-constructed instance never carries indeterminate values. The struct
// remains an aggregate, so brace initialization keeps working.
struct ConstantCopyLocation {
    const uint8_t* buffer = nullptr;
    uint32_t length = 0;
};
405 
406 struct ConstantReferenceLocation {
407     const Memory* memory;
408     uint32_t offset;
409     uint32_t length;
410 };
411 
412 class ExecutionPlan {
413    public:
414     ExecutionPlan(const ExecutionPlan&) = delete;
415     ExecutionPlan& operator=(const ExecutionPlan&) = delete;
416 
ExecutionPlan()417     ExecutionPlan() {}
~ExecutionPlan()418     ~ExecutionPlan() { delete mBody; }
419 
420     // Controller is part of the interface to a mechanism for performing an
421     // execution in N steps.
422     //
423     // The value of N may not be known beforehand if the model contains WHILE
424     // loops. See LogicalStep.
425     //
426     // Usage pattern:
427     // - Instantiate Controller with ExecutionPlan::makeController().
428     // - Call ExecutionPlan::next() on Controller N+1 times.  The first N times,
429     //   *executor is set to point to a new StepExecutor corresponding
430     //   to that step.  The N+1st time, *executor is set to nullptr,
431     //   signifying there are no more steps.
432     // - If ExecutionPlan::next() returns anything other than ANEURALNETWORKS_NO_ERROR,
433     //   a problem has occurred.
434     class Controller {
435         friend class ExecutionPlan;
436 
437        private:
438         Controller(const Controller&) = delete;
439         Controller& operator=(const Controller&) = delete;
440 
441         static const size_t kBadStepIndex = ~size_t(0);
442 
443         // A constructor for mState == SIMPLE.
444         Controller(const ExecutionPlan* plan, ExecutionBuilder* executionBuilder,
445                    const BurstBuilder* burstBuilder);
446         // A constructor for mState == COMPOUND.
447         Controller(const ExecutionPlan* plan, ExecutionBuilder* executionBuilder,
448                    const BurstBuilder* burstBuilder, uint32_t totalSizeOfTemporaries,
449                    std::map<SourceOperandIndex, uint32_t> sourceOperandToOffsetOfTemporary,
450                    std::map<SourceOperandIndex, uint32_t> sourceOperandToOffsetOfTemporary2,
451                    std::map<SourceOperandIndex, uint32_t> sourceOperandToInputIndex,
452                    std::map<SourceOperandIndex, uint32_t> sourceOperandToOutputIndex,
453                    const std::map<SourceOperandIndex, ConstantCopyLocation>&
454                            sourceOperandToConstantCopy,
455                    std::map<SourceOperandIndex, ConstantReferenceLocation>
456                            sourceOperandToConstantReference);
457 
458         // Sets the location of innerOperand to be the same as the location of outerOperand.
459         void setInput(const SourceOperandIndex& outerOperand,
460                       const SourceOperandIndex& innerOperand);
461         void setOutput(const SourceOperandIndex& outerOperand,
462                        const SourceOperandIndex& innerOperand);
463 
464         // Wait for mLastStepSyncFd to signal.
465         // No-op if mLastStepSyncFd is -1 which the mLastStepSyncFd is initialized to.
466         // mLastStepSyncFd will also be set to -1 when the most recently processed step
467         // does not generate a sync fence.
468         int waitForLastStepSyncFence() const;
469 
470         const ExecutionPlan* mPlan;
471         ExecutionBuilder* mExecutionBuilder;
472         const BurstBuilder* mBurstBuilder;
473         // Map from source operand index to an offset into mTemporaries used
474         // to represent that operand as an inter-partition input or output.
475         //
476         // The four maps
477         // - mSourceOperandToOffsetOfTemporary
478         // - mSourceOperandToInputIndex
479         // - mSourceOperandToOutputIndex
480         // - mSourceOperandToConstantReference
481         // are initialized from similarly named fields of ExecutionPlan::CompoundBody.
482         //
483         // A particular key appears in at most one map at any given time. This
484         // restriction does not apply to mSourceOperandToOffsetOfTemporary2.
485         //
486         // The maps are modified during the execution of IfStep and WhileStep.
487         // See ExecutionPlan::nextCompound().
488         std::map<SourceOperandIndex, uint32_t> mSourceOperandToOffsetOfTemporary;
489         // Map from source operand index to an additional offset into
490         // mTemporaries used for double buffering of WHILE loop output operands.
491         std::map<SourceOperandIndex, uint32_t> mSourceOperandToOffsetOfTemporary2;
492         // Map from source operand index to an input index of the main model.
493         std::map<SourceOperandIndex, uint32_t> mSourceOperandToInputIndex;
494         // Map from source operand index to an output index of the main model.
495         std::map<SourceOperandIndex, uint32_t> mSourceOperandToOutputIndex;
496         // Map from source operand index to a constant reference location.
497         // Used for WHILE loop operand initializers that are constant references.
498         std::map<SourceOperandIndex, ConstantReferenceLocation> mSourceOperandToConstantReference;
499         std::unique_ptr<MemoryAshmem> mTemporaries;
500         // Index of the next step to be processed by ExecutionPlan::next().
501         size_t mNextStepIndex;
502         // The value to reset mNextStepIndex to for partial CPU fallback.
503         size_t mFallbackNextStepIndex;
504         // Map from WhileStep index to the associated WhileState.
505         std::unordered_map<size_t, WhileState> mWhileState;
506         // The sync fence fd of the last step.
507         int mLastStepSyncFd;
508     };
509 
510     std::vector<std::shared_ptr<ExecutionBurstController>> makeBursts(int preference) const;
511 
512     std::shared_ptr<Controller> makeController(ExecutionBuilder* executionBuilder,
513                                                const BurstBuilder* burstBuilder) const;
514 
515     // Sets up a new StepExecutor and burstController (if applicable) if there
516     // is a step to execute. See ExecutionPlan::Controller.
517     // Handles control flow. See LogicalStep.
518     // syncFdOfLastStep is the sync fence fd generated by the most recently processed step.
519     int next(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor,
520              std::shared_ptr<ExecutionBurstController>* burstController = nullptr,
521              int syncFdOfLastStep = -1) const;
522 
523     // Create the same executor as the last one created by next().
524     int fallback(std::shared_ptr<Controller> controller,
525                  std::shared_ptr<StepExecutor>* executor) const;
526 
527     ExecutionStep* createNewExecutionStep(uint32_t sourceModelIndex,
528                                           const std::shared_ptr<Device> device);
529     IfStep* createNewIfStep();
530     WhileStep* createNewWhileStep();
531     GotoStep* createNewGotoStep();
532 
533     // Only legal to call when mState == COMPOUND.
getNextStepIndex()534     size_t getNextStepIndex() const { return compound()->mSteps.size(); }
535 
536     void becomeSingleStep(const std::shared_ptr<Device> device, const ModelBuilder* model);
537 
538     int finish(int32_t executionPreference, int32_t priority,
539                const std::optional<Deadline>& deadline);
540 
541     void recordTemporaryDef(SourceOperandIndex sourceOperandIndex, uint32_t stepIndex);
542 
543     void dump() const;
544 
545     void reset();
546 
isValid()547     bool isValid() const { return mState != EMPTY && mBody != nullptr && mBody->mSuccessfulFinish; }
isSimple()548     bool isSimple() const { return mState == SIMPLE; }
549     bool isSimpleCpu() const;
550 
setCaching(const std::string * cacheDir,const uint8_t * token)551     void setCaching(const std::string* cacheDir, const uint8_t* token) {
552         mCacheDir = cacheDir;
553         mToken = token;
554     }
getCacheDir()555     const std::string* getCacheDir() const { return mCacheDir; }
getCacheToken()556     const uint8_t* getCacheToken() const { return mToken; }
557 
558     // The caller is responsible for making sure the index is not out of range.
forEachStepRoleOfInput(uint32_t index,const StepRoleCallback & callback)559     void forEachStepRoleOfInput(uint32_t index, const StepRoleCallback& callback) const {
560         CHECK(mBody != nullptr);
561         mBody->forEachStepRoleOfInput(index, callback);
562     }
forEachStepRoleOfOutput(uint32_t index,const StepRoleCallback & callback)563     void forEachStepRoleOfOutput(uint32_t index, const StepRoleCallback& callback) const {
564         CHECK(mBody != nullptr);
565         mBody->forEachStepRoleOfOutput(index, callback);
566     }
567 
getSourceModels()568     SourceModels& getSourceModels() { return mSourceModels; }
getSourceModels()569     const SourceModels& getSourceModels() const { return mSourceModels; }
570 
571     // These functions are solely intended for use by unit tests of
572     // the partitioning algorithm.
573     enum class Kind {
574         ERROR,
575         EMPTY,
576         SIMPLE,
577         COMPOUND
578     };  // See operator<< defined outside this class
579     Kind forTest_getKind() const;
580     std::shared_ptr<const Device> forTest_simpleGetDevice() const;
581     const std::vector<std::shared_ptr<LogicalStep>>& forTest_compoundGetSteps() const;
582     bool forTest_hasStepModelOutputsOfUnknownSize() const;
583     const uint8_t* forTest_simpleGetCacheToken() const;
584 
585    private:
586     // Becomes a new COMPOUND step if mState == EMPTY, otherwise does nothing.
587     // Illegal to call for when mState == SIMPLE.
588     void becomeCompoundIfEmpty();
589     void findTempsAsStepModelOutputs();
590 
591     class Buffer {
592        public:
593         Buffer(void* pointer, uint32_t size);
594         Buffer(RunTimePoolInfo info, uint32_t offset);
595         void* getPointer() const;
596         uint32_t getSize() const;
597         void flush() const;
598 
599        private:
600         RunTimePoolInfo mInfo;
601         uint32_t mOffset;
602     };
603 
604     // Returns the buffer associated with a partition boundary operand.
605     std::optional<Buffer> getBuffer(std::shared_ptr<Controller> controller,
606                                     SourceOperandIndex operandIndex) const;
607     std::optional<Buffer> getBufferFromModelArgumentInfo(
608             const ModelArgumentInfo& info, const ExecutionBuilder* executionBuilder) const;
609     // Reads the value of a partition boundary boolean condition operand.
610     int readConditionValue(std::shared_ptr<Controller> controller, SourceOperandIndex operandIndex,
611                            bool* value) const;
612 
613     // Handles control flow. See LogicalStep.
614     int nextCompound(std::shared_ptr<Controller> controller,
615                      std::shared_ptr<StepExecutor>* executor,
616                      std::shared_ptr<ExecutionBurstController>* burstController) const;
617     int nextCompound(const ExecutionStep* step, std::shared_ptr<Controller> controller,
618                      std::shared_ptr<StepExecutor>* executor,
619                      std::shared_ptr<ExecutionBurstController>* burstController) const;
620     int nextCompound(const IfStep* step, std::shared_ptr<Controller> controller,
621                      std::shared_ptr<StepExecutor>* executor,
622                      std::shared_ptr<ExecutionBurstController>* burstController) const;
623     int nextCompound(const WhileStep* step, std::shared_ptr<Controller> controller,
624                      std::shared_ptr<StepExecutor>* executor,
625                      std::shared_ptr<ExecutionBurstController>* burstController) const;
626     int nextCompound(const GotoStep* step, std::shared_ptr<Controller> controller,
627                      std::shared_ptr<StepExecutor>* executor,
628                      std::shared_ptr<ExecutionBurstController>* burstController) const;
629 
630     struct Body {
~BodyBody631         virtual ~Body() {}
632         virtual void dump() const = 0;
633         virtual int finish(const SourceModels* sourceModels, int32_t executionPreference,
634                            int32_t priority, const std::optional<Deadline>& deadline) = 0;
635         virtual bool hasStepModelOutputsOfUnknownSize() const = 0;
636         virtual void forEachStepRoleOfInput(uint32_t index,
637                                             const StepRoleCallback& callback) const = 0;
638         virtual void forEachStepRoleOfOutput(uint32_t index,
639                                              const StepRoleCallback& callback) const = 0;
640         bool mSuccessfulFinish = false;
641     };
642 
643     struct SimpleBody : Body {
SimpleBodySimpleBody644         SimpleBody(std::shared_ptr<Device> device, const ModelBuilder* model,
645                    const std::string* cacheDir, const uint8_t* token)
646             : mDevice(device), mModel(model), mCacheDir(cacheDir), mToken(token) {}
647 
648         void dump() const override;
649         int finish(const SourceModels* sourceModels, int32_t executionPreference, int32_t priority,
650                    const std::optional<Deadline>& deadline) override;
hasStepModelOutputsOfUnknownSizeSimpleBody651         bool hasStepModelOutputsOfUnknownSize() const override { return false; }
652         void forEachStepRoleOfInput(uint32_t index,
653                                     const StepRoleCallback& callback) const override;
654         void forEachStepRoleOfOutput(uint32_t index,
655                                      const StepRoleCallback& callback) const override;
656 
657         std::shared_ptr<Device> mDevice;
658         const ModelBuilder* mModel;
659         std::shared_ptr<PreparedModel> mPreparedModel;
660 
661         const std::string* mCacheDir;
662         TokenHasher mToken;
663     };
664 
    // Body made up of a sequence of LogicalSteps (mSteps), together with the
    // source-operand lookup tables described on the individual members below.
    struct CompoundBody : Body {
        void dump() const override;
        int finish(const SourceModels* sourceModels, int32_t executionPreference, int32_t priority,
                   const std::optional<Deadline>& deadline) override;
        // Reports the stored flag mHasStepModelOutputOfUnknownSize.
        bool hasStepModelOutputsOfUnknownSize() const override {
            return mHasStepModelOutputOfUnknownSize;
        }
        void forEachStepRoleOfInput(uint32_t index,
                                    const StepRoleCallback& callback) const override;
        void forEachStepRoleOfOutput(uint32_t index,
                                     const StepRoleCallback& callback) const override;

        // TODO: Some of the data is working state information that
        // shouldn't be needed after we've constructed but not
        // executed the plan.

        // The logical steps that make up this body.
        std::vector<std::shared_ptr<LogicalStep>> mSteps;

        // Map from source operand index to defining ExecutionStep index.
        // Used for all (and only) TEMPORARY_VARIABLEs that are defined by
        // ExecutionSteps. Those defined by IfSteps and WhileSteps are not in
        // the map.
        std::map<SourceOperandIndex, uint32_t> mTemporaryToDefiningExecutionStep;

        // Map from source operand index to input index of the main model.
        // This map only contains SUBGRAPH_INPUTs of the main model and is used
        // to initialize ExecutionPlan::Controller::mSourceOperandToInputIndex.
        std::map<SourceOperandIndex, uint32_t> mSourceOperandToInputIndex;

        // Map from source operand index to output index of the main model.
        // This map only contains SUBGRAPH_OUTPUTs of the main model and is used
        // to initialize ExecutionPlan::Controller::mSourceOperandToOutputIndex.
        std::map<SourceOperandIndex, uint32_t> mSourceOperandToOutputIndex;

        // Map from source operand index to location of a CONSTANT_COPY operand.
        // This map only contains constant partition boundary IF and WHILE
        // operands and is used to create an ExecutionPlan::Controller.
        std::map<SourceOperandIndex, ConstantCopyLocation> mSourceOperandToBoundaryConstantCopy;

        // Map from source operand index to location of a CONSTANT_REFERENCE
        // operand.  This map only contains constant partition boundary IF and
        // WHILE operands and is used to initialize
        // ExecutionPlan::Controller::mSourceOperandToConstantReference.
        std::map<SourceOperandIndex, ConstantReferenceLocation>
                mSourceOperandToBoundaryConstantReference;

        // Value returned by hasStepModelOutputsOfUnknownSize(); false until set.
        bool mHasStepModelOutputOfUnknownSize = false;

       private:
        // NOTE(review): presumably identifies temporaries that must become
        // step model outputs; the definition is out of line -- confirm there.
        void findTempsAsStepModelOutputs();

        // Constant values that are inputs to IF and WHILE operations and lie on
        // a partition boundary ("control flow boundary constants") require
        // special treatment. We need to be able to dynamically associate those
        // values with the corresponding SUBGRAPH_INPUT operands in a referenced
        // model.
        //
        // For CONSTANT_COPY boundary operands, we copy those to temporary
        // memory and treat them similarly to TEMPORARY_VARIABLE operands in
        // Controller.
        //
        // For CONSTANT_REFERENCE boundary operands, we keep track of them in
        // ExecutionPlan::Controller::mSourceOperandToConstantReference.
        //
        // Note that for IF inputs and input-only WHILE inputs that are boundary
        // constants, we could embed those inside the referenced model, but we
        // currently don't do so. See b/148216514.
        void findControlFlowBoundaryConstants(const SourceModels* sourceModels);
    };
734 
    // mState tags the concrete type behind mBody: simple() requires
    // mState == SIMPLE and compound() requires mState == COMPOUND before
    // casting mBody to SimpleBody / CompoundBody. The initial state is EMPTY
    // with a null body.
    // NOTE(review): mBody is a raw pointer; who allocates and deletes it is not
    // visible in this part of the file -- confirm ownership.
    enum { EMPTY, SIMPLE, COMPOUND } mState = EMPTY;
    Body* mBody = nullptr;
simple()737     SimpleBody* simple() {
738         CHECK(mState == SIMPLE);
739         CHECK(mBody != nullptr);
740         return static_cast<SimpleBody*>(mBody);
741     }
simple()742     const SimpleBody* simple() const {
743         CHECK(mState == SIMPLE);
744         CHECK(mBody != nullptr);
745         return static_cast<const SimpleBody*>(mBody);
746     }
compound()747     CompoundBody* compound() {
748         CHECK(mState == COMPOUND);
749         CHECK(mBody != nullptr);
750         return static_cast<CompoundBody*>(mBody);
751     }
compound()752     const CompoundBody* compound() const {
753         CHECK(mState == COMPOUND);
754         CHECK(mBody != nullptr);
755         return static_cast<const CompoundBody*>(mBody);
756     }
757 
    // Pointers to compilation caching information in CompilationBuilder.
    // Both are non-owning and default to nullptr.
    const std::string* mCacheDir = nullptr;
    const uint8_t* mToken = nullptr;
    // Held by value, unlike the two pointers above.
    // NOTE(review): presumably the source models handed to Body::finish() as
    // its sourceModels argument -- confirm against the callers.
    SourceModels mSourceModels;
762 };
763 
764 inline std::ostream& operator<<(std::ostream& out, ExecutionPlan::Kind kind) {
765     const int intKind = static_cast<int>(kind);
766     if (kind < ExecutionPlan::Kind::ERROR || kind > ExecutionPlan::Kind::COMPOUND) {
767         return out << "<UNK(" << intKind << ")>";
768     }
769     static const char* name[] = {"ERROR", "EMPTY", "SIMPLE", "COMPOUND"};
770     return out << name[intKind];
771 }
772 
773 }  // namespace nn
774 }  // namespace android
775 
776 #endif  // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_EXECUTION_PLAN_H
777