1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Utils"
18 
19 #include "Utils.h"
20 
21 #include <android-base/logging.h>
22 #include <android-base/properties.h>
23 #include <android-base/strings.h>
24 #include <errno.h>
25 #include <poll.h>
26 
27 #include <algorithm>
28 #include <cfloat>
29 #include <functional>
30 #include <iostream>
31 #include <limits>
32 #include <numeric>
33 #include <set>
34 #include <string>
35 #include <tuple>
36 #include <unordered_map>
37 #include <utility>
38 #include <vector>
39 
40 #include "ControlFlow.h"
41 #include "NeuralNetworks.h"
42 #include "NeuralNetworksOEM.h"
43 #include "OperationResolver.h"
44 #include "ValidateHal.h"
45 
46 namespace android {
47 namespace nn {
48 
49 using namespace hal;
50 
51 constexpr PerformanceInfo kNoPerformanceInfo = {.execTime = FLT_MAX, .powerUsage = FLT_MAX};
52 
// Name of the system property read by initVLogMask() to configure verbose logging.
const char kVLogPropKey[]{"debug.nn.vlog"};
// Bit mask of enabled verbose logging tags; every bit is set until initVLogMask() runs.
int vLogMask{~0};
55 
56 // Split the space separated list of tags from verbose log setting and build the
57 // logging mask from it. note that '1' and 'all' are special cases to enable all
58 // verbose logging.
59 //
60 // NN API verbose logging setting comes from system property debug.nn.vlog.
61 // Example:
62 // setprop debug.nn.vlog 1 : enable all logging tags.
63 // setprop debug.nn.vlog "model compilation" : only enable logging for MODEL and
64 //                                             COMPILATION tags.
initVLogMask()65 void initVLogMask() {
66     vLogMask = 0;
67     const std::string vLogSetting = android::base::GetProperty(kVLogPropKey, "");
68     if (vLogSetting.empty()) {
69         return;
70     }
71 
72     std::unordered_map<std::string, int> vLogFlags = {{"1", -1},
73                                                       {"all", -1},
74                                                       {"model", MODEL},
75                                                       {"compilation", COMPILATION},
76                                                       {"execution", EXECUTION},
77                                                       {"cpuexe", CPUEXE},
78                                                       {"manager", MANAGER},
79                                                       {"driver", DRIVER},
80                                                       {"memory", MEMORY}};
81 
82     std::vector<std::string> elements = android::base::Split(vLogSetting, " ,:");
83     for (const auto& elem : elements) {
84         const auto& flag = vLogFlags.find(elem);
85         if (flag == vLogFlags.end()) {
86             LOG(ERROR) << "Unknown trace flag: " << elem;
87             continue;
88         }
89 
90         if (flag->second == -1) {
91             // -1 is used for the special values "1" and "all" that enable all
92             // tracing.
93             vLogMask = ~0;
94             return;
95         } else {
96             vLogMask |= 1 << flag->second;
97         }
98     }
99 }
100 
makeDeadline(uint64_t duration)101 Deadline makeDeadline(uint64_t duration) {
102     const auto maxTime = Deadline::max();
103     const auto currentTime = std::chrono::steady_clock::now();
104 
105     // Create Deadline. If there would be an overflow, use the max value.
106     const uint64_t remainingNanoseconds =
107             std::chrono::duration_cast<std::chrono::nanoseconds>(maxTime - currentTime).count();
108     if (duration > remainingNanoseconds) {
109         return maxTime;
110     }
111     return currentTime + std::chrono::nanoseconds{duration};
112 }
113 
makeDeadline(std::optional<uint64_t> duration)114 std::optional<Deadline> makeDeadline(std::optional<uint64_t> duration) {
115     return duration.has_value() ? makeDeadline(*duration) : std::optional<Deadline>{};
116 }
117 
// Largest nanosecond count (since the steady_clock epoch) that a
// nanosecond-resolution steady_clock time point can hold.
static uint64_t getMaxNanosecondsSinceEpoch() {
    using NanosecondTimePoint =
            std::chrono::time_point<std::chrono::steady_clock, std::chrono::nanoseconds>;
    return NanosecondTimePoint::max().time_since_epoch().count();
}
123 
makeDeadline(const OptionalTimePoint & timePoint)124 std::optional<Deadline> makeDeadline(const OptionalTimePoint& timePoint) {
125     using Discriminator = hal::OptionalTimePoint::hidl_discriminator;
126     if (timePoint.getDiscriminator() == Discriminator::none) {
127         return std::nullopt;
128     }
129     const uint64_t nanosecondsSinceEpoch = timePoint.nanosecondsSinceEpoch();
130     const uint64_t maxNanosecondsSinceEpoch = getMaxNanosecondsSinceEpoch();
131 
132     // Clamp time point to max.
133     if (nanosecondsSinceEpoch >= maxNanosecondsSinceEpoch) {
134         return Deadline::max();
135     }
136 
137     // Return provided time point.
138     return Deadline{std::chrono::nanoseconds{nanosecondsSinceEpoch}};
139 }
140 
hasDeadlinePassed(const std::optional<Deadline> & deadline)141 bool hasDeadlinePassed(const std::optional<Deadline>& deadline) {
142     if (!deadline.has_value()) {
143         return false;
144     }
145     return std::chrono::steady_clock::now() >= *deadline;
146 }
147 
makeTimePoint(const Deadline & deadline)148 static OptionalTimePoint makeTimePoint(const Deadline& deadline) {
149     const auto timeSinceEpoch = deadline.time_since_epoch();
150     const uint64_t nanosecondsSinceEpoch =
151             std::chrono::duration_cast<std::chrono::nanoseconds>(timeSinceEpoch).count();
152     OptionalTimePoint ret;
153     ret.nanosecondsSinceEpoch(nanosecondsSinceEpoch);
154     return ret;
155 }
156 
makeTimePoint(const std::optional<Deadline> & deadline)157 OptionalTimePoint makeTimePoint(const std::optional<Deadline>& deadline) {
158     return deadline.has_value() ? makeTimePoint(*deadline) : OptionalTimePoint{};
159 }
160 
isExtensionOperandType(int32_t type)161 static bool isExtensionOperandType(int32_t type) {
162     return static_cast<uint32_t>(type) > static_cast<uint32_t>(OperandTypeRange::BASE_MAX);
163 }
164 
isExtensionOperationType(ANeuralNetworksOperationType type)165 static bool isExtensionOperationType(ANeuralNetworksOperationType type) {
166     return static_cast<uint32_t>(type) > static_cast<uint32_t>(OperationTypeRange::BASE_MAX);
167 }
168 
isExtensionOperandType(OperandType type)169 bool isExtensionOperandType(OperandType type) {
170     return isExtensionOperandType(static_cast<int32_t>(type));
171 }
172 
isExtensionOperationType(OperationType type)173 bool isExtensionOperationType(OperationType type) {
174     return isExtensionOperationType(static_cast<int32_t>(type));
175 }
176 
177 namespace {
178 
179 template <typename EntryType, uint32_t entryCount, uint32_t entryCountOEM>
tableLookup(const EntryType (& table)[entryCount],const EntryType (& tableOEM)[entryCountOEM],uint32_t code)180 EntryType tableLookup(const EntryType (&table)[entryCount],
181                       const EntryType (&tableOEM)[entryCountOEM], uint32_t code) {
182     if (code < entryCount) {
183         return table[code];
184     } else if (code >= kOEMCodeBase && (code - kOEMCodeBase) < entryCountOEM) {
185         return tableOEM[code - kOEMCodeBase];
186     } else {
187         nnAssert(!"tableLookup: bad code");
188         return EntryType();
189     }
190 }
191 
192 class OperationValidationContext : public IOperationValidationContext {
193     DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext);
194 
195    public:
OperationValidationContext(const char * operationName,uint32_t inputCount,const uint32_t * inputIndexes,uint32_t outputCount,const uint32_t * outputIndexes,const Operand * operands,HalVersion halVersion)196     OperationValidationContext(const char* operationName, uint32_t inputCount,
197                                const uint32_t* inputIndexes, uint32_t outputCount,
198                                const uint32_t* outputIndexes, const Operand* operands,
199                                HalVersion halVersion)
200         : operationName(operationName),
201           inputCount(inputCount),
202           inputIndexes(inputIndexes),
203           outputCount(outputCount),
204           outputIndexes(outputIndexes),
205           operands(operands),
206           halVersion(halVersion) {}
207 
208     const char* getOperationName() const override;
209     HalVersion getHalVersion() const override;
210 
211     uint32_t getNumInputs() const override;
212     OperandType getInputType(uint32_t index) const override;
213     Shape getInputShape(uint32_t index) const override;
214     const OperandExtraParams getInputExtraParams(uint32_t index) const override;
215 
216     uint32_t getNumOutputs() const override;
217     OperandType getOutputType(uint32_t index) const override;
218     Shape getOutputShape(uint32_t index) const override;
219 
220    private:
221     const Operand* getInputOperand(uint32_t index) const;
222     const Operand* getOutputOperand(uint32_t index) const;
223 
224     const char* operationName;
225     uint32_t inputCount;
226     const uint32_t* inputIndexes;
227     uint32_t outputCount;
228     const uint32_t* outputIndexes;
229     const Operand* operands;
230     HalVersion halVersion;
231 };
232 
getOperationName() const233 const char* OperationValidationContext::getOperationName() const {
234     return operationName;
235 }
236 
getHalVersion() const237 HalVersion OperationValidationContext::getHalVersion() const {
238     return halVersion;
239 }
240 
getInputOperand(uint32_t index) const241 const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
242     CHECK(index < static_cast<uint32_t>(inputCount));
243     return &operands[inputIndexes[index]];
244 }
245 
getOutputOperand(uint32_t index) const246 const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const {
247     CHECK(index < static_cast<uint32_t>(outputCount));
248     return &operands[outputIndexes[index]];
249 }
250 
getNumInputs() const251 uint32_t OperationValidationContext::getNumInputs() const {
252     return inputCount;
253 }
254 
getNumOutputs() const255 uint32_t OperationValidationContext::getNumOutputs() const {
256     return outputCount;
257 }
258 
getInputType(uint32_t index) const259 OperandType OperationValidationContext::getInputType(uint32_t index) const {
260     return getInputOperand(index)->type;
261 }
262 
getInputShape(uint32_t index) const263 Shape OperationValidationContext::getInputShape(uint32_t index) const {
264     const Operand* operand = getInputOperand(index);
265     return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
266             operand->extraParams};
267 }
268 
getInputExtraParams(uint32_t index) const269 const OperandExtraParams OperationValidationContext::getInputExtraParams(uint32_t index) const {
270     return getInputOperand(index)->extraParams;
271 }
272 
getOutputType(uint32_t index) const273 OperandType OperationValidationContext::getOutputType(uint32_t index) const {
274     return getOutputOperand(index)->type;
275 }
276 
getOutputShape(uint32_t index) const277 Shape OperationValidationContext::getOutputShape(uint32_t index) const {
278     const Operand* operand = getOutputOperand(index);
279     return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
280             operand->extraParams};
281 }
282 
283 };  // anonymous namespace
284 
// Number of elements in the C array X. X must name an actual array, not a
// pointer. The argument is parenthesized in both uses so that compound
// expressions expand correctly.
#define COUNT(X) (sizeof(X) / sizeof((X)[0]))
286 
getOperandTypeName(OperandType type)287 std::string getOperandTypeName(OperandType type) {
288     return toString(type);
289 }
290 
getOperationName(uint32_t code)291 static std::string getOperationName(uint32_t code) {
292     return getOperationName(static_cast<OperationType>(code));
293 }
294 
getOperationName(OperationType type)295 std::string getOperationName(OperationType type) {
296     return toString(type);
297 }
298 
299 const uint32_t kSizeOfDataType[]{
300         4,  // ANEURALNETWORKS_FLOAT32
301         4,  // ANEURALNETWORKS_INT32
302         4,  // ANEURALNETWORKS_UINT32
303         4,  // ANEURALNETWORKS_TENSOR_FLOAT32
304         4,  // ANEURALNETWORKS_TENSOR_INT32
305         1,  // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
306         1,  // ANEURALNETWORKS_BOOL
307         2,  // ANEURALNETWORKS_TENSOR_QUANT16_SYMM
308         2,  // ANEURALNETWORKS_TENSOR_FLOAT16
309         1,  // ANEURALNETWORKS_TENSOR_BOOL8
310         2,  // ANEURALNETWORKS_FLOAT16
311         1,  // ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL
312         2,  // ANEURALNETWORKS_TENSOR_QUANT16_ASYMM
313         1,  // ANEURALNETWORKS_TENSOR_QUANT8_SYMM
314         1,  // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED
315         0,  // ANEURALNETWORKS_MODEL
316 };
317 
318 static_assert(COUNT(kSizeOfDataType) == kNumberOfDataTypes, "kSizeOfDataType is incorrect");
319 
320 const bool kScalarDataType[]{
321         true,   // ANEURALNETWORKS_FLOAT32
322         true,   // ANEURALNETWORKS_INT32
323         true,   // ANEURALNETWORKS_UINT32
324         false,  // ANEURALNETWORKS_TENSOR_FLOAT32
325         false,  // ANEURALNETWORKS_TENSOR_INT32
326         false,  // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
327         true,   // ANEURALNETWORKS_BOOL
328         false,  // ANEURALNETWORKS_TENSOR_QUANT16_SYMM
329         false,  // ANEURALNETWORKS_TENSOR_FLOAT16
330         false,  // ANEURALNETWORKS_TENSOR_BOOL8
331         true,   // ANEURALNETWORKS_FLOAT16
332         false,  // ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL
333         false,  // ANEURALNETWORKS_TENSOR_QUANT16_ASYMM
334         false,  // ANEURALNETWORKS_TENSOR_QUANT8_SYMM
335         false,  // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED
336         true,   // ANEURALNETWORKS_MODEL
337 };
338 
339 static_assert(COUNT(kScalarDataType) == kNumberOfDataTypes, "kScalarDataType is incorrect");
340 
341 const uint32_t kSizeOfDataTypeOEM[]{
342         0,  // ANEURALNETWORKS_OEM
343         1,  // ANEURALNETWORKS_TENSOR_OEM_BYTE
344 };
345 
346 static_assert(COUNT(kSizeOfDataTypeOEM) == kNumberOfDataTypesOEM,
347               "kSizeOfDataTypeOEM is incorrect");
348 
349 const bool kScalarDataTypeOEM[]{
350         true,   // ANEURALNETWORKS_OEM
351         false,  // ANEURALNETWORKS_TENSOR_OEM_BYTE
352 };
353 
354 static_assert(COUNT(kScalarDataTypeOEM) == kNumberOfDataTypesOEM,
355               "kScalarDataTypeOEM is incorrect");
356 
nonExtensionOperandTypeIsScalar(int type)357 bool nonExtensionOperandTypeIsScalar(int type) {
358     CHECK(!isExtensionOperandType(type)) << "Extension operand types are not supported";
359     return tableLookup(kScalarDataType, kScalarDataTypeOEM, type);
360 }
361 
nonExtensionOperandSizeOfData(OperandType type,const std::vector<uint32_t> & dimensions)362 uint32_t nonExtensionOperandSizeOfData(OperandType type, const std::vector<uint32_t>& dimensions) {
363     CHECK(!isExtensionOperandType(type)) << "Size of extension operand data is unknown";
364     int n = static_cast<int>(type);
365     uint32_t sizeOfElement = tableLookup(kSizeOfDataType, kSizeOfDataTypeOEM, n);
366     return tableLookup(kScalarDataType, kScalarDataTypeOEM, n)
367                    ? sizeOfElement
368                    : sizeOfTensorData(sizeOfElement, dimensions);
369 }
370 
// Computes sizeOfElement * product(dimensions) in 64-bit arithmetic.
// Returns {false, size} on success, or {true, 0} as soon as the running
// product exceeds the uint32_t range. An empty dimension list yields {false, 0}.
static std::pair<bool, uint32_t> sizeOfTensorDataHelper(uint32_t sizeOfElement,
                                                        const std::vector<uint32_t>& dimensions) {
    if (dimensions.empty()) {
        return {false, 0};
    }
    constexpr uint64_t kLimit = std::numeric_limits<uint32_t>::max();
    uint64_t bytes = sizeOfElement;
    // The running product stays <= kLimit before each step, and kLimit times a
    // uint32_t always fits in uint64_t, so the multiplication itself cannot wrap.
    for (size_t i = 0; i < dimensions.size(); ++i) {
        bytes *= dimensions[i];
        if (bytes > kLimit) {
            return {true, 0};
        }
    }
    return {false, static_cast<uint32_t>(bytes)};
}
385 
sizeOfTensorData(uint32_t sizeOfElement,const std::vector<uint32_t> & dimensions)386 uint32_t sizeOfTensorData(uint32_t sizeOfElement, const std::vector<uint32_t>& dimensions) {
387     const auto [overflow, size] = sizeOfTensorDataHelper(sizeOfElement, dimensions);
388     CHECK(!overflow);
389     return size;
390 }
391 
nonExtensionOperandSizeOfDataOverflowsUInt32(hal::OperandType type,const std::vector<uint32_t> & dimensions)392 bool nonExtensionOperandSizeOfDataOverflowsUInt32(hal::OperandType type,
393                                                   const std::vector<uint32_t>& dimensions) {
394     CHECK(!isExtensionOperandType(type)) << "Size of extension operand data is unknown";
395     int n = static_cast<int>(type);
396     uint32_t sizeOfElement = tableLookup(kSizeOfDataType, kSizeOfDataTypeOEM, n);
397     return tableLookup(kScalarDataType, kScalarDataTypeOEM, n)
398                    ? false
399                    : sizeOfTensorDataOverflowsUInt32(sizeOfElement, dimensions);
400 }
401 
sizeOfTensorDataOverflowsUInt32(uint32_t sizeOfElement,const std::vector<uint32_t> & dimensions)402 bool sizeOfTensorDataOverflowsUInt32(uint32_t sizeOfElement,
403                                      const std::vector<uint32_t>& dimensions) {
404     return sizeOfTensorDataHelper(sizeOfElement, dimensions).first;
405 }
406 
tensorHasUnspecifiedDimensions(int type,const uint32_t * dim,uint32_t dimCount)407 bool tensorHasUnspecifiedDimensions(int type, const uint32_t* dim, uint32_t dimCount) {
408     if (!isExtensionOperandType(type)) {
409         CHECK(!nonExtensionOperandTypeIsScalar(type))
410                 << "A scalar type can never have unspecified dimensions";
411     }
412     return dimCount == 0 || std::find(dim, dim + dimCount, 0) != (dim + dimCount);
413 }
414 
tensorHasUnspecifiedDimensions(OperandType type,const std::vector<uint32_t> & dimensions)415 bool tensorHasUnspecifiedDimensions(OperandType type, const std::vector<uint32_t>& dimensions) {
416     return tensorHasUnspecifiedDimensions(static_cast<int>(type), dimensions.data(),
417                                           dimensions.size());
418 }
419 
tensorHasUnspecifiedDimensions(const ANeuralNetworksOperandType * type)420 bool tensorHasUnspecifiedDimensions(const ANeuralNetworksOperandType* type) {
421     return tensorHasUnspecifiedDimensions(type->type, type->dimensions, type->dimensionCount);
422 }
423 
tensorHasUnspecifiedDimensions(const Operand & operand)424 bool tensorHasUnspecifiedDimensions(const Operand& operand) {
425     return tensorHasUnspecifiedDimensions(static_cast<int>(operand.type), operand.dimensions.data(),
426                                           operand.dimensions.size());
427 }
428 
// Returns how many padding bytes must be added to `index` so that an object of
// `length` bytes starts on a suitable boundary. Objects shorter than 2 bytes
// need no alignment, 2- and 3-byte objects align to 2 bytes, and objects of 4
// or more bytes align to 4 bytes.
uint32_t alignBytesNeeded(uint32_t index, size_t length) {
    uint32_t alignment = 1;
    if (length >= 4) {
        alignment = 4;
    } else if (length >= 2) {
        alignment = 2;
    }
    const uint32_t remainder = index % alignment;
    return remainder == 0 ? 0u : alignment - remainder;
}
441 
logModelToInfo(const V1_0::Model & model)442 void logModelToInfo(const V1_0::Model& model) {
443     LOG(INFO) << "V1_0::Model start";
444     LOG(INFO) << "operands" << toString(model.operands);
445     LOG(INFO) << "operations" << toString(model.operations);
446     LOG(INFO) << "inputIndexes" << toString(model.inputIndexes);
447     LOG(INFO) << "outputIndexes" << toString(model.outputIndexes);
448     LOG(INFO) << "operandValues size" << model.operandValues.size();
449     LOG(INFO) << "pools" << SHOW_IF_DEBUG(toString(model.pools));
450 }
451 
logModelToInfo(const V1_1::Model & model)452 void logModelToInfo(const V1_1::Model& model) {
453     LOG(INFO) << "V1_1::Model start";
454     LOG(INFO) << "operands" << toString(model.operands);
455     LOG(INFO) << "operations" << toString(model.operations);
456     LOG(INFO) << "inputIndexes" << toString(model.inputIndexes);
457     LOG(INFO) << "outputIndexes" << toString(model.outputIndexes);
458     LOG(INFO) << "operandValues size " << model.operandValues.size();
459     LOG(INFO) << "pools" << SHOW_IF_DEBUG(toString(model.pools));
460 }
461 
logModelToInfo(const V1_2::Model & model)462 void logModelToInfo(const V1_2::Model& model) {
463     LOG(INFO) << "V1_2::Model start";
464     LOG(INFO) << "operands" << toString(model.operands);
465     LOG(INFO) << "operations" << toString(model.operations);
466     LOG(INFO) << "inputIndexes" << toString(model.inputIndexes);
467     LOG(INFO) << "outputIndexes" << toString(model.outputIndexes);
468     LOG(INFO) << "operandValues size" << model.operandValues.size();
469     LOG(INFO) << "pools" << SHOW_IF_DEBUG(toString(model.pools));
470     LOG(INFO) << "relaxComputationFloat32toFloat16" << model.relaxComputationFloat32toFloat16;
471     LOG(INFO) << "extensionNameToPrefix" << toString(model.extensionNameToPrefix);
472 }
473 
logSubgraphToInfo(std::string label,const V1_3::Subgraph & subgraph)474 static void logSubgraphToInfo(std::string label, const V1_3::Subgraph& subgraph) {
475     LOG(INFO) << label << ".operands" << toString(subgraph.operands);
476     LOG(INFO) << label << ".operations" << toString(subgraph.operations);
477     LOG(INFO) << label << ".inputIndexes" << toString(subgraph.inputIndexes);
478     LOG(INFO) << label << ".outputIndexes" << toString(subgraph.outputIndexes);
479 }
480 
logModelToInfo(const V1_3::Model & model)481 void logModelToInfo(const V1_3::Model& model) {
482     LOG(INFO) << "V1_3::Model start";
483     logSubgraphToInfo("main", model.main);
484     for (uint32_t i = 0, n = model.referenced.size(); i < n; ++i) {
485         logSubgraphToInfo("referenced[" + std::to_string(i) + "]", model.referenced[i]);
486     }
487     LOG(INFO) << "operandValues size " << model.operandValues.size();
488     LOG(INFO) << "pools" << SHOW_IF_DEBUG(toString(model.pools));
489     LOG(INFO) << "relaxComputationFloat32toFloat16 " << model.relaxComputationFloat32toFloat16;
490     LOG(INFO) << "extensionNameToPrefix" << toString(model.extensionNameToPrefix);
491 }
492 
validateOperandSymmPerChannelQuantParams(const Operand & halOperand,const ANeuralNetworksSymmPerChannelQuantParams & channelQuant,const char * tag)493 bool validateOperandSymmPerChannelQuantParams(
494         const Operand& halOperand, const ANeuralNetworksSymmPerChannelQuantParams& channelQuant,
495         const char* tag) {
496     if (halOperand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
497         return false;
498     }
499 
500     NN_RET_CHECK_LT(channelQuant.channelDim, halOperand.dimensions.size()) << tag;
501     NN_RET_CHECK(channelQuant.scales != nullptr) << tag;
502     NN_RET_CHECK_EQ(channelQuant.scaleCount, halOperand.dimensions[channelQuant.channelDim]) << tag;
503     NN_RET_CHECK_NE(halOperand.dimensions[channelQuant.channelDim], 0u)
504             << tag << " channel dimension " << channelQuant.channelDim << " is underspecified";
505     for (uint32_t i = 0; i < halOperand.dimensions[channelQuant.channelDim]; i++) {
506         NN_RET_CHECK_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]";
507     }
508     return true;
509 }
510 
validateScalarDimensions(const ANeuralNetworksOperandType & type,const char * tag)511 static bool validateScalarDimensions(const ANeuralNetworksOperandType& type, const char* tag) {
512     NN_RET_CHECK_EQ(type.dimensionCount, 0u) << tag << " invalid dimensions for scalar type";
513     NN_RET_CHECK(type.dimensions == nullptr) << tag << " invalid dimensions for scalar type";
514     return true;
515 }
516 
validateQuant8AsymmParams(const ANeuralNetworksOperandType & type,const char * tag)517 static bool validateQuant8AsymmParams(const ANeuralNetworksOperandType& type, const char* tag) {
518     NN_RET_CHECK(0 <= type.zeroPoint && type.zeroPoint <= 255)
519             << tag << " invalid zeroPoint: " << type.zeroPoint;
520     NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
521     return true;
522 }
523 
validateQuant8AsymmSignedParams(const ANeuralNetworksOperandType & type,const char * tag)524 static bool validateQuant8AsymmSignedParams(const ANeuralNetworksOperandType& type,
525                                             const char* tag) {
526     NN_RET_CHECK(-128 <= type.zeroPoint && type.zeroPoint <= 127)
527             << tag << " invalid zeroPoint: " << type.zeroPoint;
528     NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
529     return true;
530 }
531 
validateQuant8SymmParams(const ANeuralNetworksOperandType & type,const char * tag)532 static bool validateQuant8SymmParams(const ANeuralNetworksOperandType& type, const char* tag) {
533     NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint;
534     NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
535     return true;
536 }
537 
validateQuant16AsymmParams(const ANeuralNetworksOperandType & type,const char * tag)538 static bool validateQuant16AsymmParams(const ANeuralNetworksOperandType& type, const char* tag) {
539     NN_RET_CHECK(0 <= type.zeroPoint && type.zeroPoint <= 65535)
540             << tag << " invalid zeroPoint: " << type.zeroPoint;
541     NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
542     return true;
543 }
544 
validateQuantSymmParams(const ANeuralNetworksOperandType & type,const char * tag)545 static bool validateQuantSymmParams(const ANeuralNetworksOperandType& type, const char* tag) {
546     NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
547     NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
548     return true;
549 }
550 
validateNoQuantParams(const ANeuralNetworksOperandType & type,const char * tag)551 static bool validateNoQuantParams(const ANeuralNetworksOperandType& type, const char* tag) {
552     NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
553     NN_RET_CHECK_EQ(type.scale, 0.f) << tag << " scale is not zero";
554     return true;
555 }
556 
validateTensorDimensions(const ANeuralNetworksOperandType & type,const Extension::OperandTypeInformation * const extensionOperandTypeInfo,const char * tag,bool allowPartial)557 static bool validateTensorDimensions(
558         const ANeuralNetworksOperandType& type,
559         const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag,
560         bool allowPartial) {
561     if (!allowPartial) {
562         NN_RET_CHECK_GT(type.dimensionCount, 0u) << tag << " invalid operand dimensions";
563     }
564     uint64_t size =
565             isExtensionOperandType(type.type)
566                     ? extensionOperandTypeInfo->byteSize
567                     : tableLookup(kSizeOfDataType, kSizeOfDataTypeOEM, static_cast<int>(type.type));
568     constexpr uint64_t kMaxSize = std::numeric_limits<uint32_t>::max();
569     for (uint32_t i = 0; i < type.dimensionCount; i++) {
570         if (!allowPartial) {
571             NN_RET_CHECK_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions";
572         }
573         if (type.dimensions[i] != 0) {
574             size *= type.dimensions[i];
575             NN_RET_CHECK_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize;
576         }
577     }
578     return true;
579 }
580 
validateOperandTypeHelper(const ANeuralNetworksOperandType & type,const Extension::OperandTypeInformation * const extensionOperandTypeInfo,const char * tag,bool allowPartial)581 static bool validateOperandTypeHelper(
582         const ANeuralNetworksOperandType& type,
583         const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag,
584         bool allowPartial) {
585     NN_RET_CHECK_EQ(type.dimensionCount == 0, type.dimensions == nullptr);
586     if (isExtensionOperandType(type.type)) {
587         NN_RET_CHECK(extensionOperandTypeInfo != nullptr);
588         if (extensionOperandTypeInfo->isTensor) {
589             NN_RET_CHECK(
590                     validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
591         } else {
592             NN_RET_CHECK(validateScalarDimensions(type, tag));
593         }
594         return validateNoQuantParams(type, tag);
595     }
596 
597     NN_RET_CHECK(extensionOperandTypeInfo == nullptr);
598     NN_RET_CHECK(validCode(kNumberOfDataTypes, kNumberOfDataTypesOEM, type.type))
599             << tag << " invalid OperandType: " << type.type;
600 
601     bool isScalar = tableLookup(kScalarDataType, kScalarDataTypeOEM, type.type);
602     if (isScalar) {
603         NN_RET_CHECK(validateScalarDimensions(type, tag));
604         if (type.type != ANEURALNETWORKS_OEM_SCALAR) {  // Historically, we have allowed OEM types
605                                                         // to use quantization parameters.
606             NN_RET_CHECK(validateNoQuantParams(type, tag));
607         }
608     } else {
609         NN_RET_CHECK(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
610         if (type.type == ANEURALNETWORKS_TENSOR_QUANT8_ASYMM) {
611             NN_RET_CHECK(validateQuant8AsymmParams(type, tag));
612         } else if (type.type == ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED) {
613             NN_RET_CHECK(validateQuant8AsymmSignedParams(type, tag));
614         } else if (type.type == ANEURALNETWORKS_TENSOR_QUANT8_SYMM) {
615             NN_RET_CHECK(validateQuant8SymmParams(type, tag));
616         } else if (type.type == ANEURALNETWORKS_TENSOR_QUANT16_ASYMM) {
617             NN_RET_CHECK(validateQuant16AsymmParams(type, tag));
618         } else if (type.type == ANEURALNETWORKS_TENSOR_QUANT16_SYMM) {
619             NN_RET_CHECK(validateQuantSymmParams(type, tag));
620         } else if (type.type == ANEURALNETWORKS_TENSOR_INT32) {
621             // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters.
622         } else if (type.type == ANEURALNETWORKS_TENSOR_OEM_BYTE) {
623             // Historically, we have allowed OEM types to use quantization parameters.
624         } else {
625             NN_RET_CHECK(validateNoQuantParams(type, tag));
626         }
627     }
628 
629     return true;
630 }
631 
validateOperandType(const ANeuralNetworksOperandType & type,const Extension::OperandTypeInformation * const extensionOperandTypeInfo,const char * tag,bool allowPartial)632 int validateOperandType(const ANeuralNetworksOperandType& type,
633                         const Extension::OperandTypeInformation* const extensionOperandTypeInfo,
634                         const char* tag, bool allowPartial) {
635     return validateOperandTypeHelper(type, extensionOperandTypeInfo, tag, allowPartial)
636                    ? ANEURALNETWORKS_NO_ERROR
637                    : ANEURALNETWORKS_BAD_DATA;
638 }
639 
validateOperandList(uint32_t count,const uint32_t * list,uint32_t operandCount,const char * tag)640 int validateOperandList(uint32_t count, const uint32_t* list, uint32_t operandCount,
641                         const char* tag) {
642     for (uint32_t i = 0; i < count; i++) {
643         if (list[i] >= operandCount) {
644             LOG(ERROR) << tag << " invalid operand index at " << i << " = " << list[i]
645                        << ", operandCount " << operandCount;
646             return ANEURALNETWORKS_BAD_DATA;
647         }
648     }
649     return ANEURALNETWORKS_NO_ERROR;
650 }
651 
validateOperationOperandTypes(const std::vector<Operand> & operands,uint32_t inOperandCount,const uint32_t * inOperandIndexes,const std::vector<OperandType> & inExpectedTypes,uint32_t outOperandCount,const uint32_t * outOperandIndexes,const std::vector<OperandType> & outExpectedInTypes)652 int validateOperationOperandTypes(const std::vector<Operand>& operands, uint32_t inOperandCount,
653                                   const uint32_t* inOperandIndexes,
654                                   const std::vector<OperandType>& inExpectedTypes,
655                                   uint32_t outOperandCount, const uint32_t* outOperandIndexes,
656                                   const std::vector<OperandType>& outExpectedInTypes) {
657     if (inOperandCount != static_cast<uint32_t>(inExpectedTypes.size()) ||
658         outOperandCount != static_cast<uint32_t>(outExpectedInTypes.size())) {
659         LOG(ERROR) << "Wrong operand count: expected " << inExpectedTypes.size() << " inputs and "
660                    << outExpectedInTypes.size() << " outputs,"
661                    << "got " << inOperandCount << " inputs and " << outOperandCount << " outputs";
662         return ANEURALNETWORKS_BAD_DATA;
663     }
664     for (uint32_t i = 0; i < inOperandCount; i++) {
665         if (operands[inOperandIndexes[i]].type != inExpectedTypes[i]) {
666             LOG(ERROR) << "Invalid input tensor type "
667                        << toString(operands[inOperandIndexes[i]].type) << " for input " << i
668                        << ", expected " << toString(inExpectedTypes[i]);
669             return ANEURALNETWORKS_BAD_DATA;
670         }
671     }
672     for (uint32_t i = 0; i < outOperandCount; i++) {
673         if (operands[outOperandIndexes[i]].type != outExpectedInTypes[i]) {
674             LOG(ERROR) << "Invalid output tensor type "
675                        << toString(operands[outOperandIndexes[i]].type) << " for input " << i
676                        << ", expected " << toString(outExpectedInTypes[i]);
677             return ANEURALNETWORKS_BAD_DATA;
678         }
679     }
680 
681     return ANEURALNETWORKS_NO_ERROR;
682 }
683 
validateHalVersion(ANeuralNetworksOperationType opType,HalVersion halVersion,HalVersion minSupportedHalVersion)684 static int validateHalVersion(ANeuralNetworksOperationType opType, HalVersion halVersion,
685                               HalVersion minSupportedHalVersion) {
686     if (halVersion < minSupportedHalVersion) {
687         LOG(ERROR) << "The given inputs and outputs for operation " << getOperationName(opType)
688                    << " are only supported in " << toString(minSupportedHalVersion)
689                    << " and later (validating using " << toString(halVersion) << ")";
690         return ANEURALNETWORKS_BAD_DATA;
691     }
692     return ANEURALNETWORKS_NO_ERROR;
693 }
694 
695 // Checks if two operands have the same types, ranks (if specified), dimensions
696 // (if specified), scales, zeroPoints, and extraParams.
compatible(const Operand & a,const Operand & b)697 static bool compatible(const Operand& a, const Operand& b) {
698     NN_RET_CHECK(a.type == b.type) << toString(a.type) << " != " << toString(b.type);
699     if (a.dimensions.size() != 0 && b.dimensions.size() != 0) {
700         NN_RET_CHECK_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions";
701         for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) {
702             if (a.dimensions[i] != 0 && b.dimensions[i] != 0) {
703                 NN_RET_CHECK_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions";
704             }
705         }
706     }
707     NN_RET_CHECK_EQ(a.scale, b.scale);
708     NN_RET_CHECK_EQ(a.zeroPoint, b.zeroPoint);
709     NN_RET_CHECK(a.extraParams == b.extraParams)
710             << toString(a.extraParams) << " != " << toString(b.extraParams);
711     return true;
712 }
713 
validateConditionOperand(const Operand & operand)714 static bool validateConditionOperand(const Operand& operand) {
715     NN_RET_CHECK(operand.type == OperandType::TENSOR_BOOL8)
716             << "Unexpected condition operand type: " << toString(operand.type);
717     NN_RET_CHECK_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
718     NN_RET_CHECK_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
719     return true;
720 }
721 
checkSubgraphValidationHelper(const SubgraphValidationHelper & helper)722 static void checkSubgraphValidationHelper(const SubgraphValidationHelper& helper) {
723     CHECK(helper.isValidSubgraphReference != nullptr);
724     CHECK(helper.getSubgraphInputCount != nullptr);
725     CHECK(helper.getSubgraphOutputCount != nullptr);
726     CHECK(helper.getSubgraphInputOperand != nullptr);
727     CHECK(helper.getSubgraphOutputOperand != nullptr);
728 }
729 
validateIfOperation(uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs,const std::vector<Operand> & operands,const SubgraphValidationHelper & helper)730 static bool validateIfOperation(uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
731                                 const uint32_t* outputs, const std::vector<Operand>& operands,
732                                 const SubgraphValidationHelper& helper) {
733     namespace op = operation_if;
734     checkSubgraphValidationHelper(helper);
735     NN_RET_CHECK_GE(inputCount, 3u) << "ANEURALNETWORKS_IF must have at least 3 inputs";
736     NN_RET_CHECK_GE(outputCount, 1u) << "ANEURALNETWORKS_IF must have at least 1 output";
737     auto validateBranchOperand = [&](const Operand& branchModelOperand) -> bool {
738         NN_RET_CHECK(helper.isValidSubgraphReference(branchModelOperand))
739                 << "Operand is not a valid subgraph reference";
740         const uint32_t branchModelInputCount = helper.getSubgraphInputCount(branchModelOperand);
741         const uint32_t branchModelOutputCount = helper.getSubgraphOutputCount(branchModelOperand);
742         NN_RET_CHECK_EQ(inputCount, op::kFirstInput + branchModelInputCount);
743         NN_RET_CHECK_EQ(outputCount, branchModelOutputCount);
744         for (uint32_t i = 0; i < branchModelInputCount; ++i) {
745             const Operand& innerOperand = *helper.getSubgraphInputOperand(branchModelOperand, i);
746             const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
747             NN_RET_CHECK(compatible(innerOperand, outerOperand));
748         }
749         for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
750             const Operand& innerOperand = *helper.getSubgraphOutputOperand(branchModelOperand, i);
751             const Operand& outerOperand = operands[outputs[i]];
752             NN_RET_CHECK(compatible(innerOperand, outerOperand));
753         }
754         return true;
755     };
756     NN_RET_CHECK(validateConditionOperand(operands[inputs[op::kCondBoolOperand]]))
757             << "Validation failed for IF condition operand";
758     NN_RET_CHECK(validateBranchOperand(operands[inputs[op::kThenModelOperand]]))
759             << "Validation failed for IF then model";
760     NN_RET_CHECK(validateBranchOperand(operands[inputs[op::kElseModelOperand]]))
761             << "Validation failed for IF else model";
762     return true;
763 }
764 
validateControlFlowOperandUnknownSize(const SubgraphValidationHelper & helper,const Operand & operand)765 static bool validateControlFlowOperandUnknownSize(const SubgraphValidationHelper& helper,
766                                                   const Operand& operand) {
767     if (!helper.allowControlFlowOperationWithOperandOfUnknownSize &&
768         !isExtensionOperandType(operand.type)) {
769         NN_RET_CHECK_NE(nonExtensionOperandSizeOfData(operand.type, operand.dimensions), 0u);
770     }
771     return true;
772 }
773 
validateWhileOperation(uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs,const std::vector<Operand> & operands,const SubgraphValidationHelper & helper)774 static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs,
775                                    uint32_t outputCount, const uint32_t* outputs,
776                                    const std::vector<Operand>& operands,
777                                    const SubgraphValidationHelper& helper) {
778     // Let the loop have
779     // - m >= 1 input-output operands,
780     // - k >= 0 state-only operands, and
781     // - n >= 0 input-only operands.
782     // Then
783     // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
784     // - the condition model has (m + k + n) inputs and 1 output.
785     // - the body model has (m + k + n) inputs and (m + k) outputs.
786     namespace op = operation_while;
787     checkSubgraphValidationHelper(helper);
788     NN_RET_CHECK_GE(inputCount, 3u) << "ANEURALNETWORKS_WHILE must have at least 3 inputs";
789     NN_RET_CHECK_GE(outputCount, 1u) << "ANEURALNETWORKS_WHILE must have at least 1 output";
790     auto validateCondOperand = [&](const Operand& condModelOperand) -> bool {
791         NN_RET_CHECK(helper.isValidSubgraphReference(condModelOperand))
792                 << "Operand is not a valid subgraph reference";
793         const uint32_t condModelInputCount = helper.getSubgraphInputCount(condModelOperand);
794         const uint32_t condModelOutputCount = helper.getSubgraphOutputCount(condModelOperand);
795         NN_RET_CHECK_EQ(inputCount, op::kFirstInput + condModelInputCount);
796         NN_RET_CHECK_EQ(condModelOutputCount, 1u);
797         for (uint32_t i = 0; i < condModelInputCount; ++i) {
798             const Operand& innerOperand = *helper.getSubgraphInputOperand(condModelOperand, i);
799             const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
800             NN_RET_CHECK(compatible(innerOperand, outerOperand));
801             NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, innerOperand));
802             NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand));
803         }
804         NN_RET_CHECK(
805                 validateConditionOperand(*helper.getSubgraphOutputOperand(condModelOperand, 0)));
806         return true;
807     };
808     auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> bool {
809         NN_RET_CHECK(helper.isValidSubgraphReference(bodyModelOperand))
810                 << "Operand is not a valid subgraph reference";
811         const uint32_t bodyModelInputCount = helper.getSubgraphInputCount(bodyModelOperand);
812         const uint32_t bodyModelOutputCount = helper.getSubgraphOutputCount(bodyModelOperand);
813         NN_RET_CHECK_EQ(inputCount, op::kFirstInput + bodyModelInputCount);
814         NN_RET_CHECK_GE(bodyModelOutputCount, outputCount);
815         NN_RET_CHECK_GE(bodyModelInputCount, bodyModelOutputCount);
816         const uint32_t inputOutputCount = outputCount;
817         const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
818         const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
819         for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
820             const Operand& innerOperand = *helper.getSubgraphInputOperand(bodyModelOperand, i);
821             const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
822             NN_RET_CHECK(compatible(innerOperand, outerOperand));
823             NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, innerOperand));
824             NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand));
825         }
826         for (uint32_t i = 0; i < inputOutputCount; ++i) {
827             const Operand& innerOperand = *helper.getSubgraphOutputOperand(bodyModelOperand, i);
828             const Operand& outerOperand = operands[outputs[i]];
829             NN_RET_CHECK(compatible(innerOperand, outerOperand));
830             NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand));
831         }
832         for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
833             const Operand& inputOperand = *helper.getSubgraphInputOperand(bodyModelOperand, i);
834             const Operand& outputOperand = *helper.getSubgraphOutputOperand(bodyModelOperand, i);
835             NN_RET_CHECK(compatible(inputOperand, outputOperand));
836             NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outputOperand));
837         }
838         return true;
839     };
840     NN_RET_CHECK(validateCondOperand(operands[inputs[op::kCondModelOperand]]))
841             << "Validation failed for WHILE condition model";
842     NN_RET_CHECK(validateBodyOperand(operands[inputs[op::kBodyModelOperand]]))
843             << "Validation failed for WHILE body model";
844     return true;
845 }
846 
validateOperation(ANeuralNetworksOperationType opType,uint32_t inputCount,const uint32_t * inputIndexes,uint32_t outputCount,const uint32_t * outputIndexes,const std::vector<hal::Operand> & operands,HalVersion halVersion)847 static inline int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
848                                     const uint32_t* inputIndexes, uint32_t outputCount,
849                                     const uint32_t* outputIndexes,
850                                     const std::vector<hal::Operand>& operands,
851                                     HalVersion halVersion) {
852     if (opType == ANEURALNETWORKS_IF || opType == ANEURALNETWORKS_WHILE) {
853         NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
854         LOG(ERROR) << "This validateOperation() overload does not support control flow";
855         return ANEURALNETWORKS_BAD_DATA;
856     }
857     return validateOperation(opType, inputCount, inputIndexes, outputCount, outputIndexes, operands,
858                              halVersion, {});
859 }
860 
validateOperation(ANeuralNetworksOperationType opType,uint32_t inputCount,const uint32_t * inputIndexes,uint32_t outputCount,const uint32_t * outputIndexes,const std::vector<Operand> & operands,HalVersion halVersion,const SubgraphValidationHelper & helper)861 int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
862                       const uint32_t* inputIndexes, uint32_t outputCount,
863                       const uint32_t* outputIndexes, const std::vector<Operand>& operands,
864                       HalVersion halVersion, const SubgraphValidationHelper& helper) {
865     NN_RETURN_IF_ERROR(validateOperandList(inputCount, inputIndexes,
866                                            static_cast<uint32_t>(operands.size()),
867                                            "ANeuralNetworksModel_addOperation inputs"));
868     NN_RETURN_IF_ERROR(validateOperandList(outputCount, outputIndexes,
869                                            static_cast<uint32_t>(operands.size()),
870                                            "ANeuralNetworksModel_addOperation outputs"));
871 
872     if (isExtensionOperationType(opType)) {
873         if (halVersion < HalVersion::V1_2) {
874             LOG(ERROR)
875                     << "Extension operations are supported since HAL version 1.2, validating using "
876                     << toString(halVersion);
877             return ANEURALNETWORKS_BAD_DATA;
878         }
879         // There is no other validation we can do for an extension operation.
880         return ANEURALNETWORKS_NO_ERROR;
881     }
882 
883     auto logInvalidInOutNumber = [opType, inputCount, outputCount](int expIn, int expOut) {
884         LOG(ERROR) << "Invalid number of input operands (" << inputCount << ", expected " << expIn
885                    << ") or output operands (" << outputCount << ", expected " << expOut
886                    << ") for operation " << getOperationName(opType);
887     };
888 
889     switch (opType) {
890         case ANEURALNETWORKS_OEM_OPERATION: {
891             return ANEURALNETWORKS_NO_ERROR;
892         }
893         case ANEURALNETWORKS_RESHAPE: {
894             if (inputCount != 2 || outputCount != 1) {
895                 logInvalidInOutNumber(2, 1);
896                 return ANEURALNETWORKS_BAD_DATA;
897             }
898             auto inputType = operands[inputIndexes[0]].type;
899             std::vector<OperandType> inExpectedTypes;
900             std::vector<OperandType> outExpectedTypes;
901             if (inputType == OperandType::TENSOR_FLOAT32) {
902                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
903                 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32};
904                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
905             } else if (inputType == OperandType::TENSOR_FLOAT16) {
906                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
907                 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32};
908                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
909             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
910                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
911                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32};
912                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
913             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
914                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
915                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
916                                    OperandType::TENSOR_INT32};
917                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
918             } else {
919                 LOG(ERROR) << "Unsupported input tensor type for operation "
920                            << getOperationName(opType);
921                 return ANEURALNETWORKS_BAD_DATA;
922             }
923             const auto inputRank = operands[inputIndexes[0]].dimensions.size();
924             if (inputRank > 4) {
925                 LOG(ERROR) << "Unsupported input tensor rank for operation "
926                            << getOperationName(opType);
927                 return ANEURALNETWORKS_BAD_DATA;
928             }
929             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
930                                                  inExpectedTypes, outputCount, outputIndexes,
931                                                  outExpectedTypes);
932         }
933         case ANEURALNETWORKS_DEPTH_TO_SPACE: {
934             if ((inputCount != 3 && inputCount != 2) || outputCount != 1) {
935                 LOG(ERROR) << "Invalid number of input operands (" << inputCount
936                            << ", expected 3 or 2) or output operands (" << outputCount
937                            << ", expected 1) for operation " << getOperationName(opType);
938                 return ANEURALNETWORKS_BAD_DATA;
939             }
940             auto inputType = operands[inputIndexes[0]].type;
941             std::vector<OperandType> inExpectedTypes;
942             std::vector<OperandType> outExpectedTypes;
943             if (inputType == OperandType::TENSOR_FLOAT32) {
944                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
945                 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
946                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
947             } else if (inputType == OperandType::TENSOR_FLOAT16) {
948                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
949                 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
950                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
951             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
952                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
953                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
954                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
955             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
956                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
957                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
958                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
959             } else {
960                 LOG(ERROR) << "Unsupported input tensor type for operation "
961                            << getOperationName(opType);
962                 return ANEURALNETWORKS_BAD_DATA;
963             }
964             if (inputCount == 3) {
965                 inExpectedTypes.push_back(OperandType::BOOL);
966                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
967             } else {
968                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
969             }
970             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
971                                                  inExpectedTypes, outputCount, outputIndexes,
972                                                  outExpectedTypes);
973         }
974         case ANEURALNETWORKS_SPACE_TO_DEPTH: {
975             if ((inputCount != 3 && inputCount != 2) || outputCount != 1) {
976                 LOG(ERROR) << "Invalid number of input operands (" << inputCount
977                            << ", expected 3 or 2) or output operands (" << outputCount
978                            << ", expected 1) for operation " << getOperationName(opType);
979                 return ANEURALNETWORKS_BAD_DATA;
980             }
981             auto inputType = operands[inputIndexes[0]].type;
982             std::vector<OperandType> inExpectedTypes;
983             std::vector<OperandType> outExpectedTypes;
984             if (inputType == OperandType::TENSOR_FLOAT32) {
985                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
986                 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
987                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
988             } else if (inputType == OperandType::TENSOR_FLOAT16) {
989                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
990                 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
991                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
992             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
993                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
994                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
995                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
996             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
997                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
998                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
999                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1000             } else {
1001                 LOG(ERROR) << "Unsupported input tensor type for operation "
1002                            << getOperationName(opType);
1003                 return ANEURALNETWORKS_BAD_DATA;
1004             }
1005             if (inputCount == 3) {
1006                 inExpectedTypes.push_back(OperandType::BOOL);
1007                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1008             } else {
1009                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1010             }
1011             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1012                                                  inExpectedTypes, outputCount, outputIndexes,
1013                                                  outExpectedTypes);
1014         }
1015         case ANEURALNETWORKS_EMBEDDING_LOOKUP: {
1016             if (inputCount != 2 || outputCount != 1) {
1017                 logInvalidInOutNumber(2, 1);
1018                 return ANEURALNETWORKS_BAD_DATA;
1019             }
1020             auto inputType = operands[inputIndexes[1]].type;
1021             if (inputType != OperandType::TENSOR_FLOAT16 &&
1022                 inputType != OperandType::TENSOR_FLOAT32 &&
1023                 inputType != OperandType::TENSOR_INT32 &&
1024                 inputType != OperandType::TENSOR_QUANT8_ASYMM &&
1025                 inputType != OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1026                 LOG(ERROR) << "Unsupported input tensor type for operation "
1027                            << getOperationName(opType);
1028                 return ANEURALNETWORKS_BAD_DATA;
1029             }
1030             std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, inputType};
1031             std::vector<OperandType> outExpectedTypes = {inputType};
1032             if (inputType == OperandType::TENSOR_FLOAT16 ||
1033                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1034                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1035             } else if (inputType == OperandType::TENSOR_INT32 ||
1036                        inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1037                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1038             } else {
1039                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1040             }
1041             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1042                                                  inExpectedTypes, outputCount, outputIndexes,
1043                                                  outExpectedTypes);
1044         }
1045         case ANEURALNETWORKS_HASHTABLE_LOOKUP: {
1046             if (inputCount != 3 || outputCount != 2) {
1047                 logInvalidInOutNumber(3, 2);
1048                 return ANEURALNETWORKS_BAD_DATA;
1049             }
1050             auto inputType = operands[inputIndexes[2]].type;
1051             if (inputType != OperandType::TENSOR_FLOAT32 &&
1052                 inputType != OperandType::TENSOR_INT32 &&
1053                 inputType != OperandType::TENSOR_QUANT8_ASYMM) {
1054                 LOG(ERROR) << "Unsupported input tensor type for operation "
1055                            << getOperationName(opType);
1056                 return ANEURALNETWORKS_BAD_DATA;
1057             }
1058             std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32,
1059                                                         OperandType::TENSOR_INT32, inputType};
1060             std::vector<OperandType> outExpectedTypes = {inputType,
1061                                                          OperandType::TENSOR_QUANT8_ASYMM};
1062             NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1063             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1064                                                  inExpectedTypes, outputCount, outputIndexes,
1065                                                  outExpectedTypes);
1066         }
1067         case ANEURALNETWORKS_LSH_PROJECTION: {
1068             if (inputCount != 4 || outputCount != 1) {
1069                 logInvalidInOutNumber(4, 1);
1070                 return ANEURALNETWORKS_BAD_DATA;
1071             }
1072             auto inputType = operands[inputIndexes[1]].type;
1073             if (inputType != OperandType::TENSOR_FLOAT16 &&
1074                 inputType != OperandType::TENSOR_FLOAT32 &&
1075                 inputType != OperandType::TENSOR_INT32 &&
1076                 inputType != OperandType::TENSOR_QUANT8_ASYMM) {
1077                 LOG(ERROR) << "Unsupported input tensor type for operation "
1078                            << getOperationName(opType);
1079                 return ANEURALNETWORKS_BAD_DATA;
1080             }
1081             auto hashType = operands[inputIndexes[0]].type;
1082             std::vector<OperandType> inExpectedTypes;
1083             if (hashType == OperandType::TENSOR_FLOAT16) {
1084                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1085                 inExpectedTypes = {
1086                         OperandType::TENSOR_FLOAT16,
1087                         inputType,
1088                         OperandType::TENSOR_FLOAT16,
1089                         OperandType::INT32,
1090                 };
1091             } else if (hashType == OperandType::TENSOR_FLOAT32) {
1092                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1093                 inExpectedTypes = {
1094                         OperandType::TENSOR_FLOAT32,
1095                         inputType,
1096                         OperandType::TENSOR_FLOAT32,
1097                         OperandType::INT32,
1098                 };
1099             } else {
1100                 LOG(ERROR) << "Unsupported hash tensor type for operation "
1101                            << getOperationName(opType);
1102                 return ANEURALNETWORKS_BAD_DATA;
1103             }
1104             std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
1105             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1106                                                  inExpectedTypes, outputCount, outputIndexes,
1107                                                  outExpectedTypes);
1108         }
1109         case ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM: {
1110             const uint32_t kNumOutputs = 2;
1111             const uint32_t kNumOutputsMerged = 1;
1112             const uint32_t kNumOutputsWithState = 6;
1113             const uint32_t kNumOutputsMergedWithState = 5;
1114             if (inputCount != 61 ||
1115                 (outputCount != kNumOutputs && outputCount != kNumOutputsMerged &&
1116                  outputCount != kNumOutputsWithState &&
1117                  outputCount != kNumOutputsMergedWithState)) {
1118                 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1119                            << ", expected 61) or output operands (" << outputCount
1120                            << ", expected 1, 2, 5 or 6) for operation " << getOperationName(opType);
1121                 return ANEURALNETWORKS_BAD_DATA;
1122             }
1123 
1124             std::vector<OperandType> inExpectedTypes;
1125             auto inputType = operands[inputIndexes[0]].type;
1126             if (inputType != OperandType::TENSOR_FLOAT32 &&
1127                 inputType != OperandType::TENSOR_FLOAT16) {
1128                 LOG(ERROR) << "Unsupported input tensor type for operation "
1129                            << getOperationName(opType);
1130                 return ANEURALNETWORKS_BAD_DATA;
1131             }
1132 
1133             inExpectedTypes = {};
1134             for (int i = 0; i < 48; ++i) {
1135                 inExpectedTypes.push_back(inputType);
1136             }
1137             inExpectedTypes.push_back(OperandType::INT32);
1138             inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
1139                                               ? OperandType::FLOAT32
1140                                               : OperandType::FLOAT16);
1141             inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
1142                                               ? OperandType::FLOAT32
1143                                               : OperandType::FLOAT16);
1144             inExpectedTypes.push_back(OperandType::BOOL);
1145             inExpectedTypes.push_back(OperandType::BOOL);
1146             for (int i = 0; i < 8; ++i) {
1147                 inExpectedTypes.push_back(inputType);
1148             }
1149 
1150             HalVersion minSupportedHalVersion = HalVersion::V1_2;
1151             if (outputCount == kNumOutputsWithState || outputCount == kNumOutputsMergedWithState) {
1152                 minSupportedHalVersion = HalVersion::V1_3;
1153             }
1154             NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, minSupportedHalVersion));
1155             std::vector<OperandType> outExpectedTypes(outputCount, inputType);
1156             auto status = validateOperationOperandTypes(operands, inputCount, inputIndexes,
1157                                                         inExpectedTypes, outputCount, outputIndexes,
1158                                                         outExpectedTypes);
1159             return status;
1160         }
1161         case ANEURALNETWORKS_LSTM: {
1162             if ((inputCount != 23 && inputCount != 27) || outputCount != 4) {
1163                 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1164                            << ", expected 23 or 27) or output operands (" << outputCount
1165                            << ", expected 4) for operation " << getOperationName(opType);
1166                 return ANEURALNETWORKS_BAD_DATA;
1167             }
1168             std::vector<OperandType> inExpectedTypes;
1169             std::vector<OperandType> outExpectedTypes;
1170             auto inputType = operands[inputIndexes[0]].type;
1171             if (inputType != OperandType::TENSOR_FLOAT32 &&
1172                 inputType != OperandType::TENSOR_FLOAT16) {
1173                 LOG(ERROR) << "Unsupported input tensor type for operation "
1174                            << getOperationName(opType);
1175                 return ANEURALNETWORKS_BAD_DATA;
1176             }
1177 
1178             inExpectedTypes = {inputType,         inputType, inputType, inputType, inputType,
1179                                inputType,         inputType, inputType, inputType, inputType,
1180                                inputType,         inputType, inputType, inputType, inputType,
1181                                inputType,         inputType, inputType, inputType, inputType,
1182                                OperandType::INT32};
1183             if (inputType == OperandType::TENSOR_FLOAT32) {
1184                 inExpectedTypes.push_back(OperandType::FLOAT32);
1185                 inExpectedTypes.push_back(OperandType::FLOAT32);
1186             } else {
1187                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1188                 inExpectedTypes.push_back(OperandType::FLOAT16);
1189                 inExpectedTypes.push_back(OperandType::FLOAT16);
1190             }
1191 
1192             outExpectedTypes = {inputType, inputType, inputType, inputType};
1193             if (inputCount == 23) {
1194                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1195             } else {
1196                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1197                 for (int i = 0; i < 4; ++i) {
1198                     inExpectedTypes.push_back(inputType);
1199                 }
1200             }
1201             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1202                                                  inExpectedTypes, outputCount, outputIndexes,
1203                                                  outExpectedTypes);
1204         }
1205         case ANEURALNETWORKS_QUANTIZED_16BIT_LSTM: {
1206             if (inputCount != 15 || outputCount != 2) {
1207                 logInvalidInOutNumber(15, 2);
1208                 return ANEURALNETWORKS_BAD_DATA;
1209             }
1210             NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1211             std::vector<OperandType> inExpectedTypes = {
1212                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
1213                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
1214                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
1215                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
1216                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32,
1217                     OperandType::TENSOR_INT32,        OperandType::TENSOR_INT32,
1218                     OperandType::TENSOR_INT32,        OperandType::TENSOR_QUANT16_SYMM,
1219                     OperandType::TENSOR_QUANT8_ASYMM};
1220             std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_QUANT16_SYMM,
1221                                                          OperandType::TENSOR_QUANT8_ASYMM};
1222             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1223                                                  inExpectedTypes, outputCount, outputIndexes,
1224                                                  outExpectedTypes);
1225         }
1226         case ANEURALNETWORKS_RANDOM_MULTINOMIAL: {
1227             if (inputCount != 3 || outputCount != 1) {
1228                 logInvalidInOutNumber(3, 1);
1229                 return ANEURALNETWORKS_BAD_DATA;
1230             }
1231             OperandType inputType = operands[inputIndexes[0]].type;
1232             std::vector<OperandType> inExpectedTypes;
1233             if (inputType == OperandType::TENSOR_FLOAT32 ||
1234                 inputType == OperandType::TENSOR_FLOAT16) {
1235                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1236                 inExpectedTypes = {
1237                         inputType,
1238                         OperandType::INT32,
1239                         OperandType::TENSOR_INT32,
1240                 };
1241             } else {
1242                 LOG(ERROR) << "Unsupported input tensor type for operation "
1243                            << getOperationName(opType);
1244                 return ANEURALNETWORKS_BAD_DATA;
1245             }
1246             std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
1247             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1248                                                  inExpectedTypes, outputCount, outputIndexes,
1249                                                  outExpectedTypes);
1250         }
1251         case ANEURALNETWORKS_RNN: {
1252             if (inputCount != 6 || outputCount != 2) {
1253                 logInvalidInOutNumber(6, 2);
1254                 return ANEURALNETWORKS_BAD_DATA;
1255             }
1256             OperandType inputType = operands[inputIndexes[0]].type;
1257             std::vector<OperandType> inExpectedTypes;
1258             std::vector<OperandType> outExpectedTypes;
1259             if (inputType == OperandType::TENSOR_FLOAT32) {
1260                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1261                 inExpectedTypes = {
1262                         OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
1263                         OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
1264                         OperandType::TENSOR_FLOAT32, OperandType::INT32,
1265                 };
1266                 outExpectedTypes = {
1267                         OperandType::TENSOR_FLOAT32,
1268                         OperandType::TENSOR_FLOAT32,
1269                 };
1270             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1271                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1272                 inExpectedTypes = {
1273                         OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
1274                         OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
1275                         OperandType::TENSOR_FLOAT16, OperandType::INT32,
1276                 };
1277                 outExpectedTypes = {
1278                         OperandType::TENSOR_FLOAT16,
1279                         OperandType::TENSOR_FLOAT16,
1280                 };
1281             } else {
1282                 LOG(ERROR) << "Unsupported input tensor type for operation "
1283                            << getOperationName(opType);
1284                 return ANEURALNETWORKS_BAD_DATA;
1285             }
1286             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1287                                                  inExpectedTypes, outputCount, outputIndexes,
1288                                                  outExpectedTypes);
1289         }
1290         case ANEURALNETWORKS_SVDF: {
1291             if (inputCount != 7 || outputCount != 2) {
1292                 logInvalidInOutNumber(7, 2);
1293                 return ANEURALNETWORKS_BAD_DATA;
1294             }
1295             OperandType inputType = operands[inputIndexes[0]].type;
1296             if (inputType == OperandType::TENSOR_FLOAT32) {
1297                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1298 
1299             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1300                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1301             } else {
1302                 LOG(ERROR) << "Unsupported input tensor type for operation "
1303                            << getOperationName(opType);
1304                 return ANEURALNETWORKS_BAD_DATA;
1305             }
1306             std::vector<OperandType> inExpectedTypes = {
1307                     inputType, inputType,          inputType,          inputType,
1308                     inputType, OperandType::INT32, OperandType::INT32,
1309             };
1310             std::vector<OperandType> outExpectedTypes = {inputType, inputType};
1311             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1312                                                  inExpectedTypes, outputCount, outputIndexes,
1313                                                  outExpectedTypes);
1314         }
1315         case ANEURALNETWORKS_BATCH_TO_SPACE_ND: {
1316             if ((inputCount != 3 && inputCount != 2) || outputCount != 1) {
1317                 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1318                            << ", expected 3 or 2) or output operands (" << outputCount
1319                            << ", expected 1) for operation " << getOperationName(opType);
1320                 return ANEURALNETWORKS_BAD_DATA;
1321             }
1322             auto inputType = operands[inputIndexes[0]].type;
1323             std::vector<OperandType> inExpectedTypes;
1324             std::vector<OperandType> outExpectedTypes;
1325             if (inputType == OperandType::TENSOR_FLOAT32) {
1326                 inExpectedTypes = {
1327                         OperandType::TENSOR_FLOAT32,
1328                         OperandType::TENSOR_INT32,
1329                 };
1330                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1331             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1332                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1333                 inExpectedTypes = {
1334                         OperandType::TENSOR_FLOAT16,
1335                         OperandType::TENSOR_INT32,
1336                 };
1337                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1338             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1339                 inExpectedTypes = {
1340                         OperandType::TENSOR_QUANT8_ASYMM,
1341                         OperandType::TENSOR_INT32,
1342                 };
1343                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1344             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1345                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1346                 inExpectedTypes = {
1347                         OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
1348                         OperandType::TENSOR_INT32,
1349                 };
1350                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1351             } else {
1352                 LOG(ERROR) << "Unsupported input tensor type for operation "
1353                            << getOperationName(opType);
1354                 return ANEURALNETWORKS_BAD_DATA;
1355             }
1356             if (inputCount == 3) {
1357                 inExpectedTypes.push_back(OperandType::BOOL);
1358                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1359             } else {
1360                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1361             }
1362             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1363                                                  inExpectedTypes, outputCount, outputIndexes,
1364                                                  outExpectedTypes);
1365         }
1366         case ANEURALNETWORKS_SPACE_TO_BATCH_ND: {
1367             if ((inputCount != 4 && inputCount != 3) || outputCount != 1) {
1368                 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1369                            << ", expected 4 or 3) or output operands (" << outputCount
1370                            << ", expected 1) for operation " << getOperationName(opType);
1371                 return ANEURALNETWORKS_BAD_DATA;
1372             }
1373             auto inputType = operands[inputIndexes[0]].type;
1374             std::vector<OperandType> inExpectedTypes;
1375             std::vector<OperandType> outExpectedTypes;
1376             if (inputType == OperandType::TENSOR_FLOAT32) {
1377                 inExpectedTypes = {
1378                         OperandType::TENSOR_FLOAT32,
1379                         OperandType::TENSOR_INT32,
1380                         OperandType::TENSOR_INT32,
1381                 };
1382                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1383             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1384                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1385                 inExpectedTypes = {
1386                         OperandType::TENSOR_FLOAT16,
1387                         OperandType::TENSOR_INT32,
1388                         OperandType::TENSOR_INT32,
1389                 };
1390                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1391             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1392                 if (operands[inputIndexes[0]].zeroPoint != 0) {
1393                     NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1394                 }
1395                 inExpectedTypes = {
1396                         OperandType::TENSOR_QUANT8_ASYMM,
1397                         OperandType::TENSOR_INT32,
1398                         OperandType::TENSOR_INT32,
1399                 };
1400                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1401             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1402                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1403                 inExpectedTypes = {
1404                         OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
1405                         OperandType::TENSOR_INT32,
1406                         OperandType::TENSOR_INT32,
1407                 };
1408                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1409             } else {
1410                 LOG(ERROR) << "Unsupported input tensor type for operation "
1411                            << getOperationName(opType);
1412                 return ANEURALNETWORKS_BAD_DATA;
1413             }
1414             if (inputCount == 4) {
1415                 inExpectedTypes.push_back(OperandType::BOOL);
1416                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1417             } else {
1418                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1419             }
1420             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1421                                                  inExpectedTypes, outputCount, outputIndexes,
1422                                                  outExpectedTypes);
1423         }
1424         case ANEURALNETWORKS_PAD: {
1425             if (inputCount != 2 || outputCount != 1) {
1426                 logInvalidInOutNumber(2, 1);
1427                 return ANEURALNETWORKS_BAD_DATA;
1428             }
1429             auto inputType = operands[inputIndexes[0]].type;
1430             std::vector<OperandType> inExpectedTypes;
1431             std::vector<OperandType> outExpectedTypes;
1432             if (inputType == OperandType::TENSOR_FLOAT32) {
1433                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1434                 inExpectedTypes = {
1435                         OperandType::TENSOR_FLOAT32,
1436                         OperandType::TENSOR_INT32,
1437                 };
1438                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1439             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1440                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1441                 inExpectedTypes = {
1442                         OperandType::TENSOR_FLOAT16,
1443                         OperandType::TENSOR_INT32,
1444                 };
1445                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1446             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1447                        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1448                 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1449                     NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1450                 } else {
1451                     if (operands[inputIndexes[0]].zeroPoint == 0) {
1452                         NN_RETURN_IF_ERROR(
1453                                 validateHalVersion(opType, halVersion, HalVersion::V1_1));
1454                     } else {
1455                         NN_RETURN_IF_ERROR(
1456                                 validateHalVersion(opType, halVersion, HalVersion::V1_2));
1457                     }
1458                 }
1459                 inExpectedTypes = {
1460                         inputType,
1461                         OperandType::TENSOR_INT32,
1462                 };
1463                 outExpectedTypes = {inputType};
1464             } else {
1465                 LOG(ERROR) << "Unsupported input tensor type for operation "
1466                            << getOperationName(opType);
1467                 return ANEURALNETWORKS_BAD_DATA;
1468             }
1469             const auto inputRank = operands[inputIndexes[0]].dimensions.size();
1470             if (inputRank > 4) {
1471                 LOG(ERROR) << "Unsupported input tensor rank for operation "
1472                            << getOperationName(opType);
1473                 return ANEURALNETWORKS_BAD_DATA;
1474             }
1475             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1476                                                  inExpectedTypes, outputCount, outputIndexes,
1477                                                  outExpectedTypes);
1478         }
1479         case ANEURALNETWORKS_PAD_V2: {
1480             if (inputCount != 3 || outputCount != 1) {
1481                 logInvalidInOutNumber(3, 1);
1482                 return ANEURALNETWORKS_BAD_DATA;
1483             }
1484             auto inputType = operands[inputIndexes[0]].type;
1485             std::vector<OperandType> inExpectedTypes;
1486             std::vector<OperandType> outExpectedTypes;
1487             if (inputType == OperandType::TENSOR_FLOAT32) {
1488                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1489                 inExpectedTypes = {
1490                         OperandType::TENSOR_FLOAT32,
1491                         OperandType::TENSOR_INT32,
1492                         OperandType::FLOAT32,
1493                 };
1494                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1495             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1496                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1497                 inExpectedTypes = {
1498                         OperandType::TENSOR_FLOAT16,
1499                         OperandType::TENSOR_INT32,
1500                         OperandType::FLOAT16,
1501                 };
1502                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1503             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1504                        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1505                 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1506                     NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1507                 } else {
1508                     NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1509                 }
1510                 inExpectedTypes = {
1511                         inputType,
1512                         OperandType::TENSOR_INT32,
1513                         OperandType::INT32,
1514                 };  // TODO(b/116699425): Make it UINT8.
1515                 outExpectedTypes = {inputType};
1516             } else {
1517                 LOG(ERROR) << "Unsupported input tensor type for operation "
1518                            << getOperationName(opType);
1519                 return ANEURALNETWORKS_BAD_DATA;
1520             }
1521             const auto inputRank = operands[inputIndexes[0]].dimensions.size();
1522             if (inputRank > 4) {
1523                 LOG(ERROR) << "Unsupported input tensor rank for operation "
1524                            << getOperationName(opType);
1525                 return ANEURALNETWORKS_BAD_DATA;
1526             }
1527             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1528                                                  inExpectedTypes, outputCount, outputIndexes,
1529                                                  outExpectedTypes);
1530         }
1531         case ANEURALNETWORKS_CAST: {
1532             if (inputCount != 1 || outputCount != 1) {
1533                 logInvalidInOutNumber(1, 1);
1534                 return ANEURALNETWORKS_BAD_DATA;
1535             }
1536             auto inputOperand = operands[inputIndexes[0]];
1537             auto outputOperand = operands[outputIndexes[0]];
1538             auto inputType = inputOperand.type;
1539             auto outputType = outputOperand.type;
1540             std::vector<OperandType> inExpectedTypes;
1541             std::vector<OperandType> outExpectedTypes;
1542             if ((inputType == OperandType::TENSOR_FLOAT16 ||
1543                  inputType == OperandType::TENSOR_FLOAT32 ||
1544                  inputType == OperandType::TENSOR_INT32 ||
1545                  inputType == OperandType::TENSOR_QUANT8_ASYMM) &&
1546                 (outputType == OperandType::TENSOR_FLOAT16 ||
1547                  outputType == OperandType::TENSOR_FLOAT32 ||
1548                  outputType == OperandType::TENSOR_INT32 ||
1549                  outputType == OperandType::TENSOR_QUANT8_ASYMM)) {
1550                 inExpectedTypes = {inputType};
1551                 outExpectedTypes = {outputType};
1552                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1553             } else if (inputType == OperandType::TENSOR_BOOL8 ||
1554                        inputType == OperandType::TENSOR_QUANT16_ASYMM ||
1555                        inputType == OperandType::TENSOR_QUANT16_SYMM ||
1556                        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
1557                        inputType == OperandType::TENSOR_QUANT8_SYMM) {
1558                 inExpectedTypes = {inputType};
1559                 outExpectedTypes = {inputType};  // Only identity CAST is supported.
1560                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1561             } else {
1562                 LOG(ERROR) << "Unsupported data type for operation " << getOperationName(opType);
1563                 return ANEURALNETWORKS_BAD_DATA;
1564             }
1565             // Validate that output shape is equal to input shape if dimensions
1566             // are already known.
1567             auto getNumberOfElements = [](const hardware::hidl_vec<uint32_t>& dims) {
1568                 if (dims.size() == 0) {
1569                     return 0;
1570                 }
1571                 return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
1572             };
1573             if (inputOperand.dimensions.size() != 0 && outputOperand.dimensions.size() != 0 &&
1574                 getNumberOfElements(outputOperand.dimensions) != 0 &&
1575                 inputOperand.dimensions != outputOperand.dimensions) {
1576                 return ANEURALNETWORKS_BAD_DATA;
1577             }
1578             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1579                                                  inExpectedTypes, outputCount, outputIndexes,
1580                                                  outExpectedTypes);
1581         }
1582         case ANEURALNETWORKS_MEAN: {
1583             if (inputCount != 3 || outputCount != 1) {
1584                 logInvalidInOutNumber(3, 1);
1585                 return ANEURALNETWORKS_BAD_DATA;
1586             }
1587             const auto inputRank = operands[inputIndexes[0]].dimensions.size();
1588             if (inputRank > 4) {
1589                 LOG(ERROR) << "Unsupported input tensor rank for operation "
1590                            << getOperationName(opType);
1591                 return ANEURALNETWORKS_BAD_DATA;
1592             }
1593             auto inputType = operands[inputIndexes[0]].type;
1594             if (inputType == OperandType::TENSOR_FLOAT32) {
1595                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1596             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1597                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1598             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1599                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1600             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1601                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1602             } else {
1603                 LOG(ERROR) << "Unsupported input tensor type for operation "
1604                            << getOperationName(opType);
1605                 return ANEURALNETWORKS_BAD_DATA;
1606             }
1607             std::vector<OperandType> inExpectedTypes = {inputType, OperandType::TENSOR_INT32,
1608                                                         OperandType::INT32};
1609             std::vector<OperandType> outExpectedTypes = {inputType};
1610             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1611                                                  inExpectedTypes, outputCount, outputIndexes,
1612                                                  outExpectedTypes);
1613         }
1614         case ANEURALNETWORKS_ARGMAX:
1615         case ANEURALNETWORKS_ARGMIN: {
1616             if (inputCount != 2 || outputCount != 1) {
1617                 logInvalidInOutNumber(2, 1);
1618                 return ANEURALNETWORKS_BAD_DATA;
1619             }
1620             auto inputType = operands[inputIndexes[0]].type;
1621             std::vector<OperandType> inExpectedTypes;
1622             std::vector<OperandType> outExpectedTypes;
1623             if (inputType == OperandType::TENSOR_FLOAT16 ||
1624                 inputType == OperandType::TENSOR_FLOAT32 ||
1625                 inputType == OperandType::TENSOR_INT32 ||
1626                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1627                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1628                 inExpectedTypes = {inputType, OperandType::INT32};
1629                 outExpectedTypes = {OperandType::TENSOR_INT32};
1630             } else {
1631                 LOG(ERROR) << "Unsupported input tensor type for operation "
1632                            << getOperationName(opType);
1633                 return ANEURALNETWORKS_BAD_DATA;
1634             }
1635             NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1636             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1637                                                  inExpectedTypes, outputCount, outputIndexes,
1638                                                  outExpectedTypes);
1639         }
1640         case ANEURALNETWORKS_EXPAND_DIMS: {
1641             if (inputCount != 2 || outputCount != 1) {
1642                 logInvalidInOutNumber(2, 1);
1643                 return ANEURALNETWORKS_BAD_DATA;
1644             }
1645             auto inputType = operands[inputIndexes[0]].type;
1646             std::vector<OperandType> inExpectedTypes;
1647             std::vector<OperandType> outExpectedTypes;
1648             if (inputType == OperandType::TENSOR_FLOAT16 ||
1649                 inputType == OperandType::TENSOR_FLOAT32 ||
1650                 inputType == OperandType::TENSOR_INT32 ||
1651                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1652                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1653                 inExpectedTypes = {inputType, OperandType::INT32};
1654                 outExpectedTypes = {inputType};
1655             } else {
1656                 LOG(ERROR) << "Unsupported input tensor type for operation "
1657                            << getOperationName(opType);
1658                 return ANEURALNETWORKS_BAD_DATA;
1659             }
1660             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1661                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1662             } else {
1663                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1664             }
1665             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1666                                                  inExpectedTypes, outputCount, outputIndexes,
1667                                                  outExpectedTypes);
1668         }
1669         case ANEURALNETWORKS_SPLIT: {
1670             if (inputCount != 3) {
1671                 LOG(ERROR) << "Invalid number of input operands (" << inputCount << ", expected 3)"
1672                            << getOperationName(opType);
1673                 return ANEURALNETWORKS_BAD_DATA;
1674             }
1675             auto inputType = operands[inputIndexes[0]].type;
1676             if (inputType != OperandType::TENSOR_FLOAT16 &&
1677                 inputType != OperandType::TENSOR_FLOAT32 &&
1678                 inputType != OperandType::TENSOR_INT32 &&
1679                 inputType != OperandType::TENSOR_QUANT8_ASYMM &&
1680                 inputType != OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1681                 LOG(ERROR) << "Unsupported input tensor type for operation "
1682                            << getOperationName(opType);
1683                 return ANEURALNETWORKS_BAD_DATA;
1684             }
1685             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1686                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1687             } else {
1688                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1689             }
1690             std::vector<OperandType> inExpectedTypes = {inputType, OperandType::INT32,
1691                                                         OperandType::INT32};
1692             std::vector<OperandType> outExpectedTypes(outputCount, inputType);
1693             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1694                                                  inExpectedTypes, outputCount, outputIndexes,
1695                                                  outExpectedTypes);
1696         }
1697         case ANEURALNETWORKS_MAXIMUM:
1698         case ANEURALNETWORKS_MINIMUM: {
1699             if (inputCount != 2 || outputCount != 1) {
1700                 logInvalidInOutNumber(2, 1);
1701                 return ANEURALNETWORKS_BAD_DATA;
1702             }
1703             std::vector<OperandType> inExpectedTypes;
1704             std::vector<OperandType> outExpectedTypes;
1705             OperandType inputType = operands[inputIndexes[0]].type;
1706             if (inputType == OperandType::TENSOR_FLOAT16 ||
1707                 inputType == OperandType::TENSOR_FLOAT32 ||
1708                 inputType == OperandType::TENSOR_INT32 ||
1709                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1710                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1711                 inExpectedTypes = {inputType, inputType};
1712                 outExpectedTypes = {inputType};
1713             } else {
1714                 LOG(ERROR) << "Unsupported input tensor type for operation "
1715                            << getOperationName(opType);
1716                 return ANEURALNETWORKS_BAD_DATA;
1717             }
1718             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1719                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1720             } else {
1721                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1722             }
1723             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1724                                                  inExpectedTypes, outputCount, outputIndexes,
1725                                                  outExpectedTypes);
1726         }
1727         case ANEURALNETWORKS_GROUPED_CONV_2D: {
1728             if ((inputCount != 12 && inputCount != 9) || outputCount != 1) {
1729                 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1730                            << ", expected 12 or 9) or output operands (" << outputCount
1731                            << ", expected 1) for operation " << getOperationName(opType);
1732                 return ANEURALNETWORKS_BAD_DATA;
1733             }
1734             auto inputType = operands[inputIndexes[0]].type;
1735             auto filterType = operands[inputIndexes[1]].type;
1736             std::vector<OperandType> inExpectedTypes;
1737             std::vector<OperandType> outExpectedTypes;
1738             if (inputType == OperandType::TENSOR_FLOAT32) {
1739                 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
1740                                    OperandType::TENSOR_FLOAT32, OperandType::INT32,
1741                                    OperandType::INT32,          OperandType::INT32,
1742                                    OperandType::INT32,          OperandType::INT32};
1743                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1744             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1745                 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
1746                                    OperandType::TENSOR_FLOAT16, OperandType::INT32,
1747                                    OperandType::INT32,          OperandType::INT32,
1748                                    OperandType::INT32,          OperandType::INT32};
1749                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1750             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1751                        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1752                 if (filterType != inputType &&
1753                     filterType != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
1754                     LOG(ERROR) << "Unsupported filter tensor type for operation "
1755                                << getOperationName(opType);
1756                     return ANEURALNETWORKS_BAD_DATA;
1757                 }
1758 
1759                 if (filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL &&
1760                     operands[inputIndexes[1]].extraParams.channelQuant().channelDim != 0) {
1761                     LOG(ERROR) << "Unsupported filter tensor channel dimension for operation "
1762                                << getOperationName(opType);
1763                     return ANEURALNETWORKS_BAD_DATA;
1764                 }
1765 
1766                 inExpectedTypes = {
1767                         inputType,          filterType,         OperandType::TENSOR_INT32,
1768                         OperandType::INT32, OperandType::INT32, OperandType::INT32,
1769                         OperandType::INT32, OperandType::INT32};
1770                 outExpectedTypes = {inputType};
1771             } else {
1772                 LOG(ERROR) << "Unsupported input tensor type for operation "
1773                            << getOperationName(opType);
1774                 return ANEURALNETWORKS_BAD_DATA;
1775             }
1776 
1777             if (inputCount == 12) {
1778                 std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
1779                 inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
1780                                        explicitScalarTypes.end());
1781             }
1782             inExpectedTypes.push_back(OperandType::BOOL);
1783             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1784                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1785             } else {
1786                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1787             }
1788             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1789                                                  inExpectedTypes, outputCount, outputIndexes,
1790                                                  outExpectedTypes);
1791         }
1792         case ANEURALNETWORKS_TILE: {
1793             if (inputCount != 2 || outputCount != 1) {
1794                 logInvalidInOutNumber(2, 1);
1795                 return ANEURALNETWORKS_BAD_DATA;
1796             }
1797             auto inputType = operands[inputIndexes[0]].type;
1798             std::vector<OperandType> inExpectedTypes;
1799             std::vector<OperandType> outExpectedTypes;
1800             if (inputType == OperandType::TENSOR_FLOAT16 ||
1801                 inputType == OperandType::TENSOR_FLOAT32 ||
1802                 inputType == OperandType::TENSOR_INT32 ||
1803                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1804                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1805                 inExpectedTypes = {inputType, OperandType::TENSOR_INT32};
1806                 outExpectedTypes = {inputType};
1807             } else {
1808                 LOG(ERROR) << "Unsupported input tensor type for operation "
1809                            << getOperationName(opType);
1810                 return ANEURALNETWORKS_BAD_DATA;
1811             }
1812             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1813                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1814             } else {
1815                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1816             }
1817             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1818                                                  inExpectedTypes, outputCount, outputIndexes,
1819                                                  outExpectedTypes);
1820         }
1821         case ANEURALNETWORKS_POW: {
1822             if (inputCount != 2 || outputCount != 1) {
1823                 logInvalidInOutNumber(2, 1);
1824                 return ANEURALNETWORKS_BAD_DATA;
1825             }
1826             auto inputType = operands[inputIndexes[0]].type;
1827             std::vector<OperandType> inExpectedTypes;
1828             std::vector<OperandType> outExpectedTypes;
1829             if (inputType == OperandType::TENSOR_FLOAT16 ||
1830                 inputType == OperandType::TENSOR_FLOAT32) {
1831                 inExpectedTypes = {inputType, inputType};
1832                 outExpectedTypes = {inputType};
1833             } else {
1834                 LOG(ERROR) << "Unsupported input tensor type for operation "
1835                            << getOperationName(opType);
1836                 return ANEURALNETWORKS_BAD_DATA;
1837             }
1838             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1839                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1840             } else {
1841                 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1842             }
1843             return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1844                                                  inExpectedTypes, outputCount, outputIndexes,
1845                                                  outExpectedTypes);
1846         }
1847         case ANEURALNETWORKS_IF: {
1848             NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1849             return validateIfOperation(inputCount, inputIndexes, outputCount, outputIndexes,
1850                                        operands, helper)
1851                            ? ANEURALNETWORKS_NO_ERROR
1852                            : ANEURALNETWORKS_BAD_DATA;
1853         }
1854         case ANEURALNETWORKS_WHILE: {
1855             NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1856             return validateWhileOperation(inputCount, inputIndexes, outputCount, outputIndexes,
1857                                           operands, helper)
1858                            ? ANEURALNETWORKS_NO_ERROR
1859                            : ANEURALNETWORKS_BAD_DATA;
1860         }
1861         default: {
1862             const OperationRegistration* operationRegistration =
1863                     BuiltinOperationResolver::get()->findOperation(
1864                             static_cast<OperationType>(opType));
1865             if (operationRegistration == nullptr) {
1866                 if (0 <= opType && opType < kNumberOfOperationTypes) {
1867                     LOG(ERROR) << getOperationName(opType) << " not registered";
1868                 } else {
1869                     LOG(ERROR) << "Operation type " << opType << " out of the range [0, "
1870                                << kNumberOfOperationTypes << ")";
1871                 }
1872                 return ANEURALNETWORKS_UNEXPECTED_NULL;
1873             }
1874             if (operationRegistration->validate == nullptr) {
1875                 LOG(ERROR) << "Incomplete operation registration: " << getOperationName(opType);
1876                 return ANEURALNETWORKS_UNEXPECTED_NULL;
1877             }
1878             OperationValidationContext context(operationRegistration->name, inputCount,
1879                                                inputIndexes, outputCount, outputIndexes,
1880                                                operands.data(), halVersion);
1881             if (!operationRegistration->validate(&context)) {
1882                 LOG(ERROR) << "Validation failed for operation " << getOperationName(opType);
1883                 return ANEURALNETWORKS_BAD_DATA;
1884             }
1885             return ANEURALNETWORKS_NO_ERROR;
1886         }
1887     }
1888 }
1889 
convertResultCodeToErrorStatus(int resultCode)1890 ErrorStatus convertResultCodeToErrorStatus(int resultCode) {
1891     switch (resultCode) {
1892         case ANEURALNETWORKS_NO_ERROR:
1893             return ErrorStatus::NONE;
1894 
1895         case ANEURALNETWORKS_BAD_DATA:
1896         case ANEURALNETWORKS_UNEXPECTED_NULL:
1897             return ErrorStatus::INVALID_ARGUMENT;
1898 
1899         case ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE:
1900             return ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
1901 
1902         case ANEURALNETWORKS_UNAVAILABLE_DEVICE:
1903             return ErrorStatus::DEVICE_UNAVAILABLE;
1904 
1905         case ANEURALNETWORKS_BAD_STATE:
1906         case ANEURALNETWORKS_INCOMPLETE:
1907         case ANEURALNETWORKS_OP_FAILED:
1908         case ANEURALNETWORKS_OUT_OF_MEMORY:
1909         case ANEURALNETWORKS_UNMAPPABLE:
1910         case ANEURALNETWORKS_DEAD_OBJECT:
1911             return ErrorStatus::GENERAL_FAILURE;
1912 
1913         case ANEURALNETWORKS_MISSED_DEADLINE_TRANSIENT:
1914             return ErrorStatus::MISSED_DEADLINE_TRANSIENT;
1915         case ANEURALNETWORKS_MISSED_DEADLINE_PERSISTENT:
1916             return ErrorStatus::MISSED_DEADLINE_PERSISTENT;
1917         case ANEURALNETWORKS_RESOURCE_EXHAUSTED_TRANSIENT:
1918             return ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT;
1919         case ANEURALNETWORKS_RESOURCE_EXHAUSTED_PERSISTENT:
1920             return ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT;
1921     }
1922     LOG(ERROR) << "Unknown result code " << resultCode << " mapped to ErrorStatus::GENERAL_FAILURE";
1923     return ErrorStatus::GENERAL_FAILURE;
1924 }
1925 
convertErrorStatusToResultCode(ErrorStatus status)1926 int convertErrorStatusToResultCode(ErrorStatus status) {
1927     switch (status) {
1928         case ErrorStatus::NONE:
1929             return ANEURALNETWORKS_NO_ERROR;
1930         case ErrorStatus::DEVICE_UNAVAILABLE:
1931             return ANEURALNETWORKS_UNAVAILABLE_DEVICE;
1932         case ErrorStatus::GENERAL_FAILURE:
1933             return ANEURALNETWORKS_OP_FAILED;
1934         case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
1935             return ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
1936         case ErrorStatus::INVALID_ARGUMENT:
1937             return ANEURALNETWORKS_BAD_DATA;
1938         case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
1939             return ANEURALNETWORKS_MISSED_DEADLINE_TRANSIENT;
1940         case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
1941             return ANEURALNETWORKS_MISSED_DEADLINE_PERSISTENT;
1942         case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
1943             return ANEURALNETWORKS_RESOURCE_EXHAUSTED_TRANSIENT;
1944         case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
1945             return ANEURALNETWORKS_RESOURCE_EXHAUSTED_PERSISTENT;
1946     }
1947     LOG(ERROR) << "Unknown ErrorStatus " << toString(status)
1948                << " mapped to ANEURALNETWORKS_OP_FAILED";
1949     return ANEURALNETWORKS_OP_FAILED;
1950 }
1951 
getExecutionResult(ErrorStatus status,std::vector<OutputShape> outputShapes,Timing timing)1952 std::tuple<int, std::vector<OutputShape>, Timing> getExecutionResult(
1953         ErrorStatus status, std::vector<OutputShape> outputShapes, Timing timing) {
1954     constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
1955                                   std::numeric_limits<uint64_t>::max()};
1956     const int n = convertErrorStatusToResultCode(status);
1957     if (status != ErrorStatus::NONE && status != ErrorStatus::OUTPUT_INSUFFICIENT_SIZE &&
1958         !outputShapes.empty()) {
1959         LOG(ERROR) << "The driver returned OutputShapes when it shouldn't.";
1960         outputShapes.clear();
1961     }
1962     if (status != ErrorStatus::NONE && timing != kNoTiming) {
1963         LOG(ERROR) << "The driver returned Timing when it shouldn't.";
1964         timing = kNoTiming;
1965     }
1966     return {n, std::move(outputShapes), timing};
1967 }
1968 
combineDimensions(const std::vector<uint32_t> & lhs,const std::vector<uint32_t> & rhs)1969 std::optional<std::vector<uint32_t>> combineDimensions(const std::vector<uint32_t>& lhs,
1970                                                        const std::vector<uint32_t>& rhs) {
1971     if (rhs.empty()) return lhs;
1972     if (lhs.empty()) return rhs;
1973     if (lhs.size() != rhs.size()) {
1974         LOG(ERROR) << "Incompatible ranks: " << toString(lhs) << " and " << toString(rhs);
1975         return std::nullopt;
1976     }
1977     std::vector<uint32_t> combined = lhs;
1978     for (uint32_t i = 0; i < lhs.size(); i++) {
1979         if (lhs[i] == 0) {
1980             combined[i] = rhs[i];
1981         } else if (rhs[i] != 0 && lhs[i] != rhs[i]) {
1982             LOG(ERROR) << "Incompatible dimensions: " << toString(lhs) << " and " << toString(rhs);
1983             return std::nullopt;
1984         }
1985     }
1986     return combined;
1987 }
1988 
1989 // Capabilities::operandPerformance utilities.
1990 // The field Capabilities::operandPerformance is a vector sorted by the field
1991 // Capabilities::OperandPerformance::type.
1992 
1993 template <HalVersion version>
nonExtensionOperandPerformance(PerformanceInfo perf)1994 hidl_vec<VersionedOperandPerformance<version>> nonExtensionOperandPerformance(
1995         PerformanceInfo perf) {
1996     using OpPerf = VersionedOperandPerformance<version>;
1997 
1998     // Note: range presents enumerators in declaration order, not in numerical order.
1999     static constexpr hidl_enum_range<VersionedOperandType<version>> kOperandTypeRange;
2000 
2001     std::vector<OpPerf> ret;
2002     ret.reserve(kOperandTypeRange.end() - kOperandTypeRange.begin());
2003     for (VersionedOperandType<version> type : kOperandTypeRange) {
2004         if (static_cast<OperandType>(type) != OperandType::SUBGRAPH) {
2005             ret.push_back(OpPerf{type, perf});
2006         }
2007     }
2008     std::sort(ret.begin(), ret.end(),
2009               [](const OpPerf& a, const OpPerf& b) { return a.type < b.type; });
2010 
2011     return ret;
2012 }
2013 
2014 template hal::hidl_vec<V1_2::Capabilities::OperandPerformance>
2015 nonExtensionOperandPerformance<HalVersion::V1_2>(PerformanceInfo perf);
2016 template hal::hidl_vec<V1_3::Capabilities::OperandPerformance>
2017 nonExtensionOperandPerformance<HalVersion::V1_3>(PerformanceInfo perf);
2018 
2019 template <HalVersion version>
update(hal::hidl_vec<VersionedOperandPerformance<version>> * operandPerformance,VersionedOperandType<version> type,hal::PerformanceInfo perf)2020 void update(hal::hidl_vec<VersionedOperandPerformance<version>>* operandPerformance,
2021             VersionedOperandType<version> type, hal::PerformanceInfo perf) {
2022     CHECK(operandPerformance != nullptr);
2023     const auto it =
2024             std::lower_bound(operandPerformance->begin(), operandPerformance->end(), type,
2025                              [](const VersionedOperandPerformance<version>& perf,
2026                                 VersionedOperandType<version> type) { return perf.type < type; });
2027     CHECK(it != operandPerformance->end())
2028             << toString(type) << " not in " << toString(*operandPerformance);
2029     it->info = perf;
2030 }
2031 
update(hidl_vec<V1_2::Capabilities::OperandPerformance> * operandPerformance,V1_2::OperandType type,PerformanceInfo perf)2032 void update(hidl_vec<V1_2::Capabilities::OperandPerformance>* operandPerformance,
2033             V1_2::OperandType type, PerformanceInfo perf) {
2034     update<HalVersion::V1_2>(operandPerformance, type, perf);
2035 }
update(hidl_vec<V1_3::Capabilities::OperandPerformance> * operandPerformance,V1_3::OperandType type,PerformanceInfo perf)2036 void update(hidl_vec<V1_3::Capabilities::OperandPerformance>* operandPerformance,
2037             V1_3::OperandType type, PerformanceInfo perf) {
2038     update<HalVersion::V1_3>(operandPerformance, type, perf);
2039 }
2040 
2041 template <HalVersion version>
lookup(const hidl_vec<VersionedOperandPerformance<version>> & operandPerformance,VersionedOperandType<version> type)2042 PerformanceInfo lookup(const hidl_vec<VersionedOperandPerformance<version>>& operandPerformance,
2043                        VersionedOperandType<version> type) {
2044     const auto it = std::lower_bound(operandPerformance.begin(), operandPerformance.end(), type,
2045                                      [](const VersionedOperandPerformance<version>& perf,
2046                                         VersionedOperandType<version> type) {
2047                                          return static_cast<OperandType>(perf.type) <
2048                                                 static_cast<OperandType>(type);
2049                                      });
2050     if (it == operandPerformance.end()) {
2051         LOG(WARNING) << "No PerformanceInfo for " << toString(type);
2052         return kNoPerformanceInfo;
2053     } else {
2054         return it->info;
2055     }
2056 }
2057 
lookup(const hidl_vec<V1_2::Capabilities::OperandPerformance> & operandPerformance,V1_2::OperandType type)2058 PerformanceInfo lookup(const hidl_vec<V1_2::Capabilities::OperandPerformance>& operandPerformance,
2059                        V1_2::OperandType type) {
2060     return lookup<HalVersion::V1_2>(operandPerformance, type);
2061 }
lookup(const hidl_vec<V1_3::Capabilities::OperandPerformance> & operandPerformance,V1_3::OperandType type)2062 PerformanceInfo lookup(const hidl_vec<V1_3::Capabilities::OperandPerformance>& operandPerformance,
2063                        V1_3::OperandType type) {
2064     CHECK(type != V1_3::OperandType::SUBGRAPH)
2065             << "Use Capabilities::ifPerformance or Capabilities::whilePerformance";
2066     return lookup<HalVersion::V1_3>(operandPerformance, type);
2067 }
2068 
2069 // Versioning
2070 
2071 // In Android P, most data types are treated as having the same performance as TENSOR_QUANT8_ASYMM.
2072 // This array must be in sorted order.
2073 static const OperandType kQuantized8PerformanceConsistentWithP[] = {
2074         OperandType::INT32, OperandType::UINT32, OperandType::TENSOR_INT32, OperandType::OEM,
2075         OperandType::TENSOR_OEM_BYTE};
2076 
isQuantized8PerformanceConsistentWithP(const V1_2::Capabilities & capabilities)2077 static bool isQuantized8PerformanceConsistentWithP(const V1_2::Capabilities& capabilities) {
2078     const PerformanceInfo quantized8Performance =
2079             lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_QUANT8_ASYMM);
2080     return std::all_of(std::begin(kQuantized8PerformanceConsistentWithP),
2081                        std::end(kQuantized8PerformanceConsistentWithP),
2082                        [quantized8Performance, &capabilities](OperandType type) {
2083                            return quantized8Performance ==
2084                                   lookup(capabilities.operandPerformance,
2085                                          static_cast<V1_2::OperandType>(type));
2086                        });
2087 }
2088 
isQuantized8PerformanceConsistentWithP(const V1_3::Capabilities & capabilities)2089 static bool isQuantized8PerformanceConsistentWithP(const V1_3::Capabilities& capabilities) {
2090     const PerformanceInfo quantized8Performance =
2091             lookup(capabilities.operandPerformance, OperandType::TENSOR_QUANT8_ASYMM);
2092     return std::all_of(std::begin(kQuantized8PerformanceConsistentWithP),
2093                        std::end(kQuantized8PerformanceConsistentWithP),
2094                        [quantized8Performance, &capabilities](OperandType type) {
2095                            return quantized8Performance ==
2096                                   lookup(capabilities.operandPerformance, type);
2097                        });
2098 }
2099 
makeQuantized8PerformanceConsistentWithP(PerformanceInfo quantized8Performance)2100 static hidl_vec<V1_2::Capabilities::OperandPerformance> makeQuantized8PerformanceConsistentWithP(
2101         PerformanceInfo quantized8Performance) {
2102     hidl_vec<V1_2::Capabilities::OperandPerformance> ret(
2103             std::size(kQuantized8PerformanceConsistentWithP));
2104     std::transform(
2105             std::begin(kQuantized8PerformanceConsistentWithP),
2106             std::end(kQuantized8PerformanceConsistentWithP), ret.begin(),
2107             [quantized8Performance](OperandType type) -> V1_2::Capabilities::OperandPerformance {
2108                 return {static_cast<V1_2::OperandType>(type), quantized8Performance};
2109             });
2110     return ret;
2111 }
2112 
compliantWithV1_0(const V1_0::Capabilities &)2113 bool compliantWithV1_0(const V1_0::Capabilities&) {
2114     return true;
2115 }
2116 
compliantWithV1_0(const V1_1::Capabilities & capabilities)2117 bool compliantWithV1_0(const V1_1::Capabilities& capabilities) {
2118     return capabilities.relaxedFloat32toFloat16Performance == capabilities.float32Performance;
2119 }
2120 
compliantWithV1_0(const V1_2::Capabilities & capabilities)2121 bool compliantWithV1_0(const V1_2::Capabilities& capabilities) {
2122     const PerformanceInfo perfTensorFloat32 =
2123             lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT32);
2124     const PerformanceInfo perfFloat32 =
2125             lookup(capabilities.operandPerformance, V1_2::OperandType::FLOAT32);
2126     if (perfTensorFloat32 != perfFloat32 ||
2127         perfTensorFloat32 != capabilities.relaxedFloat32toFloat16PerformanceTensor ||
2128         perfFloat32 != capabilities.relaxedFloat32toFloat16PerformanceScalar) {
2129         return false;
2130     }
2131 
2132     return isQuantized8PerformanceConsistentWithP(capabilities);
2133 }
2134 
compliantWithV1_0(const V1_3::Capabilities & capabilities)2135 bool compliantWithV1_0(const V1_3::Capabilities& capabilities) {
2136     const PerformanceInfo perfTensorFloat32 =
2137             lookup(capabilities.operandPerformance, OperandType::TENSOR_FLOAT32);
2138     const PerformanceInfo perfFloat32 =
2139             lookup(capabilities.operandPerformance, OperandType::FLOAT32);
2140     if (perfTensorFloat32 != perfFloat32 ||
2141         perfTensorFloat32 != capabilities.relaxedFloat32toFloat16PerformanceTensor ||
2142         perfFloat32 != capabilities.relaxedFloat32toFloat16PerformanceScalar) {
2143         return false;
2144     }
2145 
2146     return isQuantized8PerformanceConsistentWithP(capabilities);
2147 }
2148 
compliantWithV1_1(const V1_0::Capabilities &)2149 bool compliantWithV1_1(const V1_0::Capabilities&) {
2150     return true;
2151 }
2152 
compliantWithV1_1(const V1_1::Capabilities &)2153 bool compliantWithV1_1(const V1_1::Capabilities&) {
2154     return true;
2155 }
2156 
compliantWithV1_1(const V1_2::Capabilities & capabilities)2157 bool compliantWithV1_1(const V1_2::Capabilities& capabilities) {
2158     if ((capabilities.relaxedFloat32toFloat16PerformanceTensor !=
2159          capabilities.relaxedFloat32toFloat16PerformanceScalar) ||
2160         (lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT32) !=
2161          lookup(capabilities.operandPerformance, V1_2::OperandType::FLOAT32))) {
2162         return false;
2163     }
2164 
2165     return isQuantized8PerformanceConsistentWithP(capabilities);
2166 }
2167 
compliantWithV1_1(const V1_3::Capabilities & capabilities)2168 bool compliantWithV1_1(const V1_3::Capabilities& capabilities) {
2169     if ((capabilities.relaxedFloat32toFloat16PerformanceTensor !=
2170          capabilities.relaxedFloat32toFloat16PerformanceScalar) ||
2171         (lookup(capabilities.operandPerformance, OperandType::TENSOR_FLOAT32) !=
2172          lookup(capabilities.operandPerformance, OperandType::FLOAT32))) {
2173         return false;
2174     }
2175 
2176     return isQuantized8PerformanceConsistentWithP(capabilities);
2177 }
2178 
compliantWithV1_2(const V1_0::Capabilities &)2179 bool compliantWithV1_2(const V1_0::Capabilities&) {
2180     return true;
2181 }
2182 
compliantWithV1_2(const V1_1::Capabilities &)2183 bool compliantWithV1_2(const V1_1::Capabilities&) {
2184     return true;
2185 }
2186 
compliantWithV1_2(const V1_2::Capabilities &)2187 bool compliantWithV1_2(const V1_2::Capabilities&) {
2188     return true;
2189 }
2190 
compliantWithV1_2(const V1_3::Capabilities &)2191 bool compliantWithV1_2(const V1_3::Capabilities&) {
2192     return true;
2193 }
2194 
compliantWithV1_3(const V1_0::Capabilities &)2195 bool compliantWithV1_3(const V1_0::Capabilities&) {
2196     return true;
2197 }
2198 
compliantWithV1_3(const V1_1::Capabilities &)2199 bool compliantWithV1_3(const V1_1::Capabilities&) {
2200     return true;
2201 }
2202 
compliantWithV1_3(const V1_2::Capabilities &)2203 bool compliantWithV1_3(const V1_2::Capabilities&) {
2204     return true;
2205 }
2206 
compliantWithV1_3(const V1_3::Capabilities &)2207 bool compliantWithV1_3(const V1_3::Capabilities&) {
2208     return true;
2209 }
2210 
convertToV1_0(V1_0::ErrorStatus status)2211 V1_0::ErrorStatus convertToV1_0(V1_0::ErrorStatus status) {
2212     return status;
2213 }
2214 
convertToV1_0(V1_3::ErrorStatus status)2215 V1_0::ErrorStatus convertToV1_0(V1_3::ErrorStatus status) {
2216     switch (status) {
2217         case V1_3::ErrorStatus::NONE:
2218             return V1_0::ErrorStatus::NONE;
2219         case V1_3::ErrorStatus::DEVICE_UNAVAILABLE:
2220             return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
2221         case V1_3::ErrorStatus::GENERAL_FAILURE:
2222             return V1_0::ErrorStatus::GENERAL_FAILURE;
2223         case V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
2224             return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
2225         case V1_3::ErrorStatus::INVALID_ARGUMENT:
2226             return V1_0::ErrorStatus::INVALID_ARGUMENT;
2227         case V1_3::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
2228             return V1_0::ErrorStatus::GENERAL_FAILURE;
2229         case V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
2230             return V1_0::ErrorStatus::GENERAL_FAILURE;
2231         case V1_3::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
2232             return V1_0::ErrorStatus::GENERAL_FAILURE;
2233         case V1_3::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
2234             return V1_0::ErrorStatus::GENERAL_FAILURE;
2235     }
2236     LOG(ERROR) << "Unknown ErrorStatus: " << toString(status) << " mapped to GENERAL_FAILURE";
2237     return V1_0::ErrorStatus::GENERAL_FAILURE;
2238 }
2239 
convertToV1_3(V1_0::ErrorStatus status)2240 V1_3::ErrorStatus convertToV1_3(V1_0::ErrorStatus status) {
2241     return static_cast<V1_3::ErrorStatus>(status);
2242 }
2243 
convertToV1_3(V1_3::ErrorStatus status)2244 V1_3::ErrorStatus convertToV1_3(V1_3::ErrorStatus status) {
2245     return status;
2246 }
2247 
uncheckedConvertToV1_0(V1_1::OperationType type)2248 static V1_0::OperationType uncheckedConvertToV1_0(V1_1::OperationType type) {
2249     return static_cast<V1_0::OperationType>(type);
2250 }
2251 
uncheckedConvertToV1_0(V1_2::OperationType type)2252 static V1_0::OperationType uncheckedConvertToV1_0(V1_2::OperationType type) {
2253     return static_cast<V1_0::OperationType>(type);
2254 }
2255 
uncheckedConvertToV1_0(V1_3::OperationType type)2256 V1_0::OperationType uncheckedConvertToV1_0(V1_3::OperationType type) {
2257     return static_cast<V1_0::OperationType>(type);
2258 }
2259 
convertToV1_1(V1_0::OperationType type)2260 static V1_1::OperationType convertToV1_1(V1_0::OperationType type) {
2261     return static_cast<V1_1::OperationType>(type);
2262 }
2263 
uncheckedConvertToV1_1(V1_2::OperationType type)2264 static V1_1::OperationType uncheckedConvertToV1_1(V1_2::OperationType type) {
2265     return static_cast<V1_1::OperationType>(type);
2266 }
2267 
uncheckedConvertToV1_1(V1_3::OperationType type)2268 V1_1::OperationType uncheckedConvertToV1_1(V1_3::OperationType type) {
2269     return static_cast<V1_1::OperationType>(type);
2270 }
2271 
convertToV1_2(V1_0::OperationType type)2272 static V1_2::OperationType convertToV1_2(V1_0::OperationType type) {
2273     return static_cast<V1_2::OperationType>(type);
2274 }
2275 
convertToV1_2(V1_1::OperationType type)2276 static V1_2::OperationType convertToV1_2(V1_1::OperationType type) {
2277     return static_cast<V1_2::OperationType>(type);
2278 }
2279 
uncheckedConvertToV1_2(V1_3::OperationType type)2280 V1_2::OperationType uncheckedConvertToV1_2(V1_3::OperationType type) {
2281     return static_cast<V1_2::OperationType>(type);
2282 }
2283 
convertToV1_3(V1_0::OperationType type)2284 static V1_3::OperationType convertToV1_3(V1_0::OperationType type) {
2285     return static_cast<V1_3::OperationType>(type);
2286 }
2287 
convertToV1_3(V1_1::OperationType type)2288 static V1_3::OperationType convertToV1_3(V1_1::OperationType type) {
2289     return static_cast<V1_3::OperationType>(type);
2290 }
2291 
convertToV1_3(V1_2::OperationType type)2292 static V1_3::OperationType convertToV1_3(V1_2::OperationType type) {
2293     return static_cast<V1_3::OperationType>(type);
2294 }
2295 
// Identity conversion: returns a copy of the already-V1_0 capabilities.
V1_0::Capabilities convertToV1_0(const V1_0::Capabilities& capabilities) { return capabilities; }
2299 
convertToV1_0(const V1_1::Capabilities & capabilities)2300 V1_0::Capabilities convertToV1_0(const V1_1::Capabilities& capabilities) {
2301     if (!compliantWithV1_0(capabilities)) {
2302         LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
2303                    << " from V1_1::Capabilities to V1_0::Capabilities";
2304     }
2305     return {.float32Performance = capabilities.float32Performance,
2306             .quantized8Performance = capabilities.quantized8Performance};
2307 }
2308 
convertToV1_0(const V1_2::Capabilities & capabilities)2309 V1_0::Capabilities convertToV1_0(const V1_2::Capabilities& capabilities) {
2310     if (!compliantWithV1_0(capabilities)) {
2311         LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
2312                    << " from V1_2::Capabilities to V1_0::Capabilities";
2313     }
2314     return {.float32Performance =
2315                     lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT32),
2316             .quantized8Performance = lookup(capabilities.operandPerformance,
2317                                             V1_2::OperandType::TENSOR_QUANT8_ASYMM)};
2318 }
2319 
convertToV1_0(const V1_3::Capabilities & capabilities)2320 V1_0::Capabilities convertToV1_0(const V1_3::Capabilities& capabilities) {
2321     if (!compliantWithV1_0(capabilities)) {
2322         LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
2323                    << " from V1_3::Capabilities to V1_0::Capabilities";
2324     }
2325     return {.float32Performance =
2326                     lookup(capabilities.operandPerformance, OperandType::TENSOR_FLOAT32),
2327             .quantized8Performance =
2328                     lookup(capabilities.operandPerformance, OperandType::TENSOR_QUANT8_ASYMM)};
2329 }
2330 
convertToV1_1(const V1_0::Capabilities & capabilities)2331 V1_1::Capabilities convertToV1_1(const V1_0::Capabilities& capabilities) {
2332     return {.float32Performance = capabilities.float32Performance,
2333             .quantized8Performance = capabilities.quantized8Performance,
2334             .relaxedFloat32toFloat16Performance = capabilities.float32Performance};
2335 }
2336 
// Identity conversion: returns a copy of the already-V1_1 capabilities.
V1_1::Capabilities convertToV1_1(const V1_1::Capabilities& capabilities) { return capabilities; }
2340 
convertToV1_1(const V1_2::Capabilities & capabilities)2341 V1_1::Capabilities convertToV1_1(const V1_2::Capabilities& capabilities) {
2342     if (!compliantWithV1_1(capabilities)) {
2343         LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
2344                    << " from V1_2::Capabilities to V1_1::Capabilities";
2345     }
2346     return {.float32Performance =
2347                     lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT32),
2348             .quantized8Performance =
2349                     lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_QUANT8_ASYMM),
2350             .relaxedFloat32toFloat16Performance =
2351                     capabilities.relaxedFloat32toFloat16PerformanceTensor};
2352 }
2353 
convertToV1_1(const V1_3::Capabilities & capabilities)2354 V1_1::Capabilities convertToV1_1(const V1_3::Capabilities& capabilities) {
2355     if (!compliantWithV1_1(capabilities)) {
2356         LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
2357                    << " from V1_3::Capabilities to V1_1::Capabilities";
2358     }
2359     return {.float32Performance =
2360                     lookup(capabilities.operandPerformance, OperandType::TENSOR_FLOAT32),
2361             .quantized8Performance =
2362                     lookup(capabilities.operandPerformance, OperandType::TENSOR_QUANT8_ASYMM),
2363             .relaxedFloat32toFloat16Performance =
2364                     capabilities.relaxedFloat32toFloat16PerformanceTensor};
2365 }
2366 
convertToV1_2(const V1_0::Capabilities & capabilities)2367 V1_2::Capabilities convertToV1_2(const V1_0::Capabilities& capabilities) {
2368     V1_2::Capabilities ret = {
2369             .relaxedFloat32toFloat16PerformanceScalar = capabilities.float32Performance,
2370             .relaxedFloat32toFloat16PerformanceTensor = capabilities.float32Performance,
2371             .operandPerformance =
2372                     makeQuantized8PerformanceConsistentWithP(capabilities.quantized8Performance)};
2373     auto& opPerf = ret.operandPerformance;
2374     opPerf.resize(opPerf.size() + 2);
2375     opPerf[opPerf.size() - 2] = {V1_2::OperandType::TENSOR_FLOAT32,
2376                                  capabilities.float32Performance};
2377     opPerf[opPerf.size() - 1] = {V1_2::OperandType::FLOAT32, capabilities.float32Performance};
2378     using OperandPerformance = V1_2::Capabilities::OperandPerformance;
2379     std::sort(opPerf.begin(), opPerf.end(),
2380               [](const OperandPerformance& a, const OperandPerformance& b) {
2381                   return a.type < b.type;
2382               });
2383     return ret;
2384 }
2385 
convertToV1_2(const V1_1::Capabilities & capabilities)2386 V1_2::Capabilities convertToV1_2(const V1_1::Capabilities& capabilities) {
2387     V1_2::Capabilities ret = {.relaxedFloat32toFloat16PerformanceScalar =
2388                                       capabilities.relaxedFloat32toFloat16Performance,
2389                               .relaxedFloat32toFloat16PerformanceTensor =
2390                                       capabilities.relaxedFloat32toFloat16Performance,
2391                               .operandPerformance = makeQuantized8PerformanceConsistentWithP(
2392                                       capabilities.quantized8Performance)};
2393     auto& opPerf = ret.operandPerformance;
2394     opPerf.resize(opPerf.size() + 2);
2395     opPerf[opPerf.size() - 2] = {V1_2::OperandType::TENSOR_FLOAT32,
2396                                  capabilities.float32Performance};
2397     opPerf[opPerf.size() - 1] = {V1_2::OperandType::FLOAT32, capabilities.float32Performance};
2398     using OperandPerformance = V1_2::Capabilities::OperandPerformance;
2399     std::sort(opPerf.begin(), opPerf.end(),
2400               [](const OperandPerformance& a, const OperandPerformance& b) {
2401                   return a.type < b.type;
2402               });
2403     return ret;
2404 }
2405 
// Identity conversion: returns a copy of the already-V1_2 capabilities.
V1_2::Capabilities convertToV1_2(const V1_2::Capabilities& capabilities) { return capabilities; }
2409 
convertToV1_2(const V1_3::Capabilities & capabilities)2410 V1_2::Capabilities convertToV1_2(const V1_3::Capabilities& capabilities) {
2411     V1_2::Capabilities ret = {
2412             .relaxedFloat32toFloat16PerformanceScalar =
2413                     capabilities.relaxedFloat32toFloat16PerformanceScalar,
2414             .relaxedFloat32toFloat16PerformanceTensor =
2415                     capabilities.relaxedFloat32toFloat16PerformanceTensor,
2416     };
2417     const auto& inputOpPerf = capabilities.operandPerformance;
2418     hidl_vec<V1_3::Capabilities::OperandPerformance> opPerfSupported;
2419     opPerfSupported.resize(inputOpPerf.size());
2420     auto last =
2421             std::copy_if(inputOpPerf.begin(), inputOpPerf.end(), opPerfSupported.begin(),
2422                          [](V1_3::Capabilities::OperandPerformance opPerf) {
2423                              return validOperandType(static_cast<V1_2::OperandType>(opPerf.type));
2424                          });
2425     opPerfSupported.resize(std::distance(opPerfSupported.begin(), last));
2426 
2427     auto& convertedOpPerf = ret.operandPerformance;
2428     convertedOpPerf.resize(opPerfSupported.size());
2429     std::transform(opPerfSupported.begin(), opPerfSupported.end(), convertedOpPerf.begin(),
2430                    [](V1_3::Capabilities::OperandPerformance opPerf) {
2431                        return V1_2::Capabilities::OperandPerformance{
2432                                static_cast<V1_2::OperandType>(opPerf.type), opPerf.info};
2433                    });
2434     return ret;
2435 }
2436 
convertToV1_3(const V1_0::Capabilities & capabilities)2437 V1_3::Capabilities convertToV1_3(const V1_0::Capabilities& capabilities) {
2438     return convertToV1_3(convertToV1_2(capabilities));
2439 }
2440 
convertToV1_3(const V1_1::Capabilities & capabilities)2441 V1_3::Capabilities convertToV1_3(const V1_1::Capabilities& capabilities) {
2442     return convertToV1_3(convertToV1_2(capabilities));
2443 }
2444 
convertToV1_3(const V1_2::Capabilities & capabilities)2445 V1_3::Capabilities convertToV1_3(const V1_2::Capabilities& capabilities) {
2446     V1_3::Capabilities ret = {
2447             .relaxedFloat32toFloat16PerformanceScalar =
2448                     capabilities.relaxedFloat32toFloat16PerformanceScalar,
2449             .relaxedFloat32toFloat16PerformanceTensor =
2450                     capabilities.relaxedFloat32toFloat16PerformanceTensor,
2451             .ifPerformance = kNoPerformanceInfo,
2452             .whilePerformance = kNoPerformanceInfo,
2453     };
2454     auto& opPerf = ret.operandPerformance;
2455     opPerf.resize(capabilities.operandPerformance.size());
2456     std::transform(capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
2457                    opPerf.begin(), [](V1_2::Capabilities::OperandPerformance opPerf) {
2458                        return V1_3::Capabilities::OperandPerformance{
2459                                static_cast<V1_3::OperandType>(opPerf.type), opPerf.info};
2460                    });
2461     return ret;
2462 }
2463 
// Identity conversion: returns a copy of the already-V1_3 capabilities.
V1_3::Capabilities convertToV1_3(const V1_3::Capabilities& capabilities) { return capabilities; }
2467 
uncheckedConvertToV1_0(const V1_1::Operation & operation)2468 static V1_0::Operation uncheckedConvertToV1_0(const V1_1::Operation& operation) {
2469     return {.type = uncheckedConvertToV1_0(operation.type),
2470             .inputs = operation.inputs,
2471             .outputs = operation.outputs};
2472 }
2473 
convertToV1_1(const V1_0::Operation & operation)2474 static V1_1::Operation convertToV1_1(const V1_0::Operation& operation) {
2475     return {.type = convertToV1_1(operation.type),
2476             .inputs = operation.inputs,
2477             .outputs = operation.outputs};
2478 }
2479 
// Element-wise uncheckedConvertToV1_0 over a list of V1_1 operations. Order is preserved.
static hidl_vec<V1_0::Operation> uncheckedConvertToV1_0(
        const hidl_vec<V1_1::Operation>& operations) {
    hidl_vec<V1_0::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = uncheckedConvertToV1_0(operations[i]);
    }
    return converted;
}
2488 
// Element-wise convertToV1_1 over a list of V1_0 operations. Order is preserved.
static hidl_vec<V1_1::Operation> convertToV1_1(const hidl_vec<V1_0::Operation>& operations) {
    hidl_vec<V1_1::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = convertToV1_1(operations[i]);
    }
    return converted;
}
2495 
compliantWithV1_0(const V1_3::Operand & operand)2496 bool compliantWithV1_0(const V1_3::Operand& operand) {
2497     return validOperandType(static_cast<V1_0::OperandType>(operand.type)) &&
2498            (nonExtensionOperandTypeIsScalar(static_cast<int>(operand.type)) ||
2499             operand.dimensions.size() != 0) &&
2500            compliantWithV1_0(operand.lifetime);
2501 }
2502 
compliantWithV1_2(const V1_3::Operand & operand)2503 bool compliantWithV1_2(const V1_3::Operand& operand) {
2504     return validOperandType(static_cast<V1_2::OperandType>(operand.type)) &&
2505            compliantWithV1_0(operand.lifetime);
2506 }
2507 
compliantWithV1_3(const V1_3::Operand & operand)2508 bool compliantWithV1_3(const V1_3::Operand& operand) {
2509     return true;
2510 }
2511 
compliantWith(HalVersion version,const V1_3::Model & model,std::set<uint32_t> * noncompliantOperations)2512 static bool compliantWith(HalVersion version, const V1_3::Model& model,
2513                           std::set<uint32_t>* noncompliantOperations) {
2514     // A boolean vector indicating whether each pool is compliant with the target HAL version.
2515     std::vector<bool> isPoolCompliant(model.pools.size(), false);
2516     std::transform(model.pools.begin(), model.pools.end(), isPoolCompliant.begin(),
2517                    [version](const hidl_memory& pool) { return validatePool(pool, version); });
2518 
2519     // A boolean vector indicating whether each operand is compliant with the target HAL version.
2520     std::vector<bool> isOperandCompliant(model.main.operands.size(), false);
2521     std::transform(model.main.operands.begin(), model.main.operands.end(),
2522                    isOperandCompliant.begin(), [&isPoolCompliant, version](const Operand& op) {
2523                        bool is_operand_compliant = false;
2524                        switch (version) {
2525                            case HalVersion::UNKNOWN:
2526                                is_operand_compliant = false;
2527                                break;
2528                            case HalVersion::V1_0:
2529                                is_operand_compliant = compliantWithV1_0(op);
2530                                break;
2531                            case HalVersion::V1_1:
2532                                // There is no V1_1::Operand -- both V1_0::Model
2533                                // and V1_1::Model use V1_0::Operand.
2534                                is_operand_compliant = compliantWithV1_0(op);
2535                                break;
2536                            case HalVersion::V1_2:
2537                                is_operand_compliant = compliantWithV1_2(op);
2538                                break;
2539                            case HalVersion::V1_3:
2540                                is_operand_compliant = compliantWithV1_3(op);
2541                                break;
2542                        }
2543                        return is_operand_compliant &&
2544                               !(op.lifetime == OperandLifeTime::CONSTANT_REFERENCE &&
2545                                 !isPoolCompliant[op.location.poolIndex]);
2546                    });
2547 
2548     auto allOperandsCompliant = [&isOperandCompliant](const hidl_vec<uint32_t>& indices) {
2549         return std::all_of(
2550                 indices.begin(), indices.end(),
2551                 [&isOperandCompliant](const uint32_t ind) { return isOperandCompliant[ind]; });
2552     };
2553 
2554     auto localValidateOperation = [&model, version, &allOperandsCompliant](const Operation& op) {
2555         if (!allOperandsCompliant(op.inputs) || !allOperandsCompliant(op.outputs)) return false;
2556         int error = validateOperation(
2557                 static_cast<int32_t>(op.type), op.inputs.size(),
2558                 op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
2559                 op.outputs.size() > 0 ? op.outputs.data() : nullptr, model.main.operands, version);
2560         return error == ANEURALNETWORKS_NO_ERROR;
2561     };
2562 
2563     if (noncompliantOperations) {
2564         CHECK(noncompliantOperations->empty());
2565         for (uint32_t idx = 0; idx < model.main.operations.size(); ++idx) {
2566             if (!localValidateOperation(model.main.operations[idx])) {
2567                 noncompliantOperations->insert(idx);
2568             }
2569         }
2570         return noncompliantOperations->empty();
2571     } else {
2572         return std::all_of(model.main.operations.begin(), model.main.operations.end(),
2573                            localValidateOperation);
2574     }
2575 }
2576 
compliantWithV1_0(const V1_0::Model & model)2577 bool compliantWithV1_0(const V1_0::Model& model) {
2578     return true;
2579 }
2580 
compliantWithV1_0(const V1_1::Model & model)2581 bool compliantWithV1_0(const V1_1::Model& model) {
2582     // In addition to new enumeration values being introduced in V1_1::Model, a
2583     // new flag was introduced to indicate whether or not float32 data can be
2584     // calculated using float16 units. This 'relaxComputationFloat32toFloat16'
2585     // flag is not relevant in whether a V1_1::Model is compliant with a
2586     // V1_0::Model because all 1.0 drivers require strict calculation by default
2587     // in the P NN runtime. Even if fp16 calculations are allowed, they can
2588     // still be computed by a strict fp32 driver.
2589     return std::all_of(
2590             model.operations.begin(), model.operations.end(), [&model](const V1_1::Operation& op) {
2591                 int error = validateOperation(static_cast<int32_t>(op.type), op.inputs.size(),
2592                                               op.inputs.size() > 0 ? op.inputs.data() : nullptr,
2593                                               op.outputs.size(),
2594                                               op.outputs.size() > 0 ? op.outputs.data() : nullptr,
2595                                               convertToV1_3(model.operands), HalVersion::V1_0);
2596                 return error == ANEURALNETWORKS_NO_ERROR;
2597             });
2598 }
2599 
compliantWithV1_0(const V1_2::Model & model,std::set<uint32_t> * noncompliantOperations)2600 bool compliantWithV1_0(const V1_2::Model& model, std::set<uint32_t>* noncompliantOperations) {
2601     return compliantWith(HalVersion::V1_0, convertToV1_3(model), noncompliantOperations);
2602 }
2603 
compliantWithV1_0(const V1_3::Model & model,std::set<uint32_t> * noncompliantOperations)2604 bool compliantWithV1_0(const V1_3::Model& model, std::set<uint32_t>* noncompliantOperations) {
2605     return compliantWith(HalVersion::V1_0, model, noncompliantOperations);
2606 }
2607 
compliantWithV1_1(const V1_0::Model &)2608 bool compliantWithV1_1(const V1_0::Model&) {
2609     return true;
2610 }
2611 
compliantWithV1_1(const V1_1::Model &)2612 bool compliantWithV1_1(const V1_1::Model&) {
2613     return true;
2614 }
2615 
compliantWithV1_1(const V1_2::Model & model,std::set<uint32_t> * noncompliantOperations)2616 bool compliantWithV1_1(const V1_2::Model& model, std::set<uint32_t>* noncompliantOperations) {
2617     return compliantWith(HalVersion::V1_1, convertToV1_3(model), noncompliantOperations);
2618 }
2619 
compliantWithV1_1(const V1_3::Model & model,std::set<uint32_t> * noncompliantOperations)2620 bool compliantWithV1_1(const V1_3::Model& model, std::set<uint32_t>* noncompliantOperations) {
2621     return compliantWith(HalVersion::V1_1, model, noncompliantOperations);
2622 }
2623 
compliantWithV1_2(const V1_0::Model &)2624 bool compliantWithV1_2(const V1_0::Model&) {
2625     return true;
2626 }
2627 
compliantWithV1_2(const V1_1::Model &)2628 bool compliantWithV1_2(const V1_1::Model&) {
2629     return true;
2630 }
2631 
compliantWithV1_2(const V1_2::Model &,std::set<uint32_t> * noncompliantOperations)2632 bool compliantWithV1_2(const V1_2::Model&, std::set<uint32_t>* noncompliantOperations) {
2633     return true;
2634 }
2635 
compliantWithV1_2(const V1_3::Model & model,std::set<uint32_t> * noncompliantOperations)2636 bool compliantWithV1_2(const V1_3::Model& model, std::set<uint32_t>* noncompliantOperations) {
2637     return compliantWith(HalVersion::V1_2, model, noncompliantOperations);
2638 }
2639 
uncheckedConvertToV1_0(const V1_2::Operation & operation)2640 static V1_0::Operation uncheckedConvertToV1_0(const V1_2::Operation& operation) {
2641     return {.type = uncheckedConvertToV1_0(operation.type),
2642             .inputs = operation.inputs,
2643             .outputs = operation.outputs};
2644 }
2645 
uncheckedConvertToV1_0(const V1_3::Operation & operation)2646 static V1_0::Operation uncheckedConvertToV1_0(const V1_3::Operation& operation) {
2647     return {.type = uncheckedConvertToV1_0(operation.type),
2648             .inputs = operation.inputs,
2649             .outputs = operation.outputs};
2650 }
2651 
uncheckedConvertToV1_1(const V1_2::Operation & operation)2652 static V1_1::Operation uncheckedConvertToV1_1(const V1_2::Operation& operation) {
2653     return {.type = uncheckedConvertToV1_1(operation.type),
2654             .inputs = operation.inputs,
2655             .outputs = operation.outputs};
2656 }
2657 
uncheckedConvertToV1_1(const V1_3::Operation & operation)2658 static V1_1::Operation uncheckedConvertToV1_1(const V1_3::Operation& operation) {
2659     return {.type = uncheckedConvertToV1_1(operation.type),
2660             .inputs = operation.inputs,
2661             .outputs = operation.outputs};
2662 }
2663 
convertToV1_2(const V1_0::Operation & operation)2664 static V1_2::Operation convertToV1_2(const V1_0::Operation& operation) {
2665     return {.type = convertToV1_2(operation.type),
2666             .inputs = operation.inputs,
2667             .outputs = operation.outputs};
2668 }
2669 
convertToV1_2(const V1_1::Operation & operation)2670 static V1_2::Operation convertToV1_2(const V1_1::Operation& operation) {
2671     return {.type = convertToV1_2(operation.type),
2672             .inputs = operation.inputs,
2673             .outputs = operation.outputs};
2674 }
2675 
uncheckedConvertToV1_2(const V1_3::Operation & operation)2676 static V1_2::Operation uncheckedConvertToV1_2(const V1_3::Operation& operation) {
2677     return {.type = uncheckedConvertToV1_2(operation.type),
2678             .inputs = operation.inputs,
2679             .outputs = operation.outputs};
2680 }
2681 
convertToV1_3(const V1_0::Operation & operation)2682 static V1_3::Operation convertToV1_3(const V1_0::Operation& operation) {
2683     return {.type = convertToV1_3(operation.type),
2684             .inputs = operation.inputs,
2685             .outputs = operation.outputs};
2686 }
2687 
convertToV1_3(const V1_1::Operation & operation)2688 static V1_3::Operation convertToV1_3(const V1_1::Operation& operation) {
2689     return {.type = convertToV1_3(operation.type),
2690             .inputs = operation.inputs,
2691             .outputs = operation.outputs};
2692 }
2693 
convertToV1_3(const V1_2::Operation & operation)2694 static V1_3::Operation convertToV1_3(const V1_2::Operation& operation) {
2695     return {.type = convertToV1_3(operation.type),
2696             .inputs = operation.inputs,
2697             .outputs = operation.outputs};
2698 }
2699 
// Element-wise uncheckedConvertToV1_0 over a list of V1_3 operations. Order is preserved.
static hidl_vec<V1_0::Operation> uncheckedConvertToV1_0(
        const hidl_vec<V1_3::Operation>& operations) {
    hidl_vec<V1_0::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = uncheckedConvertToV1_0(operations[i]);
    }
    return converted;
}
2708 
// Element-wise uncheckedConvertToV1_0 over a list of V1_2 operations. Order is preserved.
static hidl_vec<V1_0::Operation> uncheckedConvertToV1_0(
        const hidl_vec<V1_2::Operation>& operations) {
    hidl_vec<V1_0::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = uncheckedConvertToV1_0(operations[i]);
    }
    return converted;
}
2717 
// Element-wise uncheckedConvertToV1_2 over a list of V1_3 operations. Order is preserved.
static hidl_vec<V1_2::Operation> uncheckedConvertToV1_2(
        const hidl_vec<V1_3::Operation>& operations) {
    hidl_vec<V1_2::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = uncheckedConvertToV1_2(operations[i]);
    }
    return converted;
}
2726 
// Element-wise uncheckedConvertToV1_1 over a list of V1_2 operations. Order is preserved.
static hidl_vec<V1_1::Operation> uncheckedConvertToV1_1(
        const hidl_vec<V1_2::Operation>& operations) {
    hidl_vec<V1_1::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = uncheckedConvertToV1_1(operations[i]);
    }
    return converted;
}
2735 
// Element-wise uncheckedConvertToV1_1 over a list of V1_3 operations. Order is preserved.
static hidl_vec<V1_1::Operation> uncheckedConvertToV1_1(
        const hidl_vec<V1_3::Operation>& operations) {
    hidl_vec<V1_1::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = uncheckedConvertToV1_1(operations[i]);
    }
    return converted;
}
2744 
// Element-wise convertToV1_2 over a list of V1_0 operations. Order is preserved.
static hidl_vec<V1_2::Operation> convertToV1_2(const hidl_vec<V1_0::Operation>& operations) {
    hidl_vec<V1_2::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = convertToV1_2(operations[i]);
    }
    return converted;
}
2751 
// Element-wise convertToV1_2 over a list of V1_1 operations. Order is preserved.
static hidl_vec<V1_2::Operation> convertToV1_2(const hidl_vec<V1_1::Operation>& operations) {
    hidl_vec<V1_2::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = convertToV1_2(operations[i]);
    }
    return converted;
}
2758 
// Element-wise convertToV1_3 over a list of V1_0 operations. Order is preserved.
static hidl_vec<V1_3::Operation> convertToV1_3(const hidl_vec<V1_0::Operation>& operations) {
    hidl_vec<V1_3::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = convertToV1_3(operations[i]);
    }
    return converted;
}
2765 
// Element-wise convertToV1_3 over a list of V1_1 operations. Order is preserved.
static hidl_vec<V1_3::Operation> convertToV1_3(const hidl_vec<V1_1::Operation>& operations) {
    hidl_vec<V1_3::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = convertToV1_3(operations[i]);
    }
    return converted;
}
2772 
// Element-wise convertToV1_3 over a list of V1_2 operations. Order is preserved.
static hidl_vec<V1_3::Operation> convertToV1_3(const hidl_vec<V1_2::Operation>& operations) {
    hidl_vec<V1_3::Operation> converted(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        converted[i] = convertToV1_3(operations[i]);
    }
    return converted;
}
2779 
compliantWithV1_0(const V1_2::OperandType & operandType)2780 static bool compliantWithV1_0(const V1_2::OperandType& operandType) {
2781     return validOperandType(static_cast<V1_0::OperandType>(operandType));
2782 }
2783 
compliantWithV1_0(const V1_3::OperandType & operandType)2784 static bool compliantWithV1_0(const V1_3::OperandType& operandType) {
2785     return validOperandType(static_cast<V1_0::OperandType>(operandType));
2786 }
2787 
compliantWithV1_2(const V1_3::OperandType & operandType)2788 static bool compliantWithV1_2(const V1_3::OperandType& operandType) {
2789     return validOperandType(static_cast<V1_2::OperandType>(operandType));
2790 }
2791 
convertToV1_0(const V1_2::OperandType & operandType)2792 V1_0::OperandType convertToV1_0(const V1_2::OperandType& operandType) {
2793     if (!compliantWithV1_0(operandType)) {
2794         LOG(ERROR) << "Upcasting non-compliant operand type " << toString(operandType)
2795                    << " from V1_2::OperandType to V1_0::OperandType";
2796     }
2797     return static_cast<V1_0::OperandType>(operandType);
2798 }
2799 
convertToV1_2(const V1_0::OperandType & operandType)2800 V1_2::OperandType convertToV1_2(const V1_0::OperandType& operandType) {
2801     return static_cast<V1_2::OperandType>(operandType);
2802 }
2803 
convertToV1_2(const V1_3::OperandType & operandType)2804 V1_2::OperandType convertToV1_2(const V1_3::OperandType& operandType) {
2805     if (!compliantWithV1_2(operandType)) {
2806         LOG(ERROR) << "Upcasting non-compliant operand type " << toString(operandType)
2807                    << " from V1_3::OperandType to V1_2::OperandType";
2808     }
2809     return static_cast<V1_2::OperandType>(operandType);
2810 }
2811 
convertToV1_0(const V1_3::OperandType & operandType)2812 V1_0::OperandType convertToV1_0(const V1_3::OperandType& operandType) {
2813     if (!compliantWithV1_0(operandType)) {
2814         LOG(ERROR) << "Upcasting non-compliant operand type " << toString(operandType)
2815                    << " from V1_3::Operand to V1_0::Operand";
2816     }
2817     return static_cast<V1_0::OperandType>(operandType);
2818 }
2819 
compliantWithV1_0(hal::V1_0::OperandLifeTime lifetime)2820 bool compliantWithV1_0(hal::V1_0::OperandLifeTime lifetime) {
2821     return true;
2822 }
2823 
compliantWithV1_0(hal::V1_3::OperandLifeTime lifetime)2824 bool compliantWithV1_0(hal::V1_3::OperandLifeTime lifetime) {
2825     return lifetime != V1_3::OperandLifeTime::SUBGRAPH;
2826 }
2827 
compliantWithV1_3(hal::V1_0::OperandLifeTime lifetime)2828 bool compliantWithV1_3(hal::V1_0::OperandLifeTime lifetime) {
2829     return true;
2830 }
2831 
compliantWithV1_3(hal::V1_3::OperandLifeTime lifetime)2832 bool compliantWithV1_3(hal::V1_3::OperandLifeTime lifetime) {
2833     return true;
2834 }
2835 
convertToV1_0(V1_0::OperandLifeTime lifetime)2836 V1_0::OperandLifeTime convertToV1_0(V1_0::OperandLifeTime lifetime) {
2837     return lifetime;
2838 }
2839 
convertToV1_0(V1_3::OperandLifeTime lifetime)2840 V1_0::OperandLifeTime convertToV1_0(V1_3::OperandLifeTime lifetime) {
2841     if (!compliantWithV1_0(lifetime)) {
2842         LOG(ERROR) << "Upcasting non-compliant lifetime " << toString(lifetime)
2843                    << " from V1_3 to V1_0";
2844     }
2845     return static_cast<V1_0::OperandLifeTime>(lifetime);
2846 }
2847 
convertToV1_3(V1_0::OperandLifeTime lifetime)2848 V1_3::OperandLifeTime convertToV1_3(V1_0::OperandLifeTime lifetime) {
2849     return static_cast<V1_3::OperandLifeTime>(lifetime);
2850 }
2851 
convertToV1_3(V1_3::OperandLifeTime lifetime)2852 V1_3::OperandLifeTime convertToV1_3(V1_3::OperandLifeTime lifetime) {
2853     return lifetime;
2854 }
2855 
convertToV1_0(const V1_2::Operand & operand)2856 V1_0::Operand convertToV1_0(const V1_2::Operand& operand) {
2857     return {.type = convertToV1_0(operand.type),
2858             .dimensions = operand.dimensions,
2859             .numberOfConsumers = operand.numberOfConsumers,
2860             .scale = operand.scale,
2861             .zeroPoint = operand.zeroPoint,
2862             .lifetime = convertToV1_0(operand.lifetime),
2863             .location = operand.location};
2864 }
2865 
convertToV1_0(const V1_3::Operand & operand)2866 V1_0::Operand convertToV1_0(const V1_3::Operand& operand) {
2867     return {.type = convertToV1_0(operand.type),
2868             .dimensions = operand.dimensions,
2869             .numberOfConsumers = operand.numberOfConsumers,
2870             .scale = operand.scale,
2871             .zeroPoint = operand.zeroPoint,
2872             .lifetime = convertToV1_0(operand.lifetime),
2873             .location = operand.location};
2874 }
2875 
convertToV1_2(const V1_0::Operand & operand)2876 V1_2::Operand convertToV1_2(const V1_0::Operand& operand) {
2877     return {.type = convertToV1_2(operand.type),
2878             .dimensions = operand.dimensions,
2879             .numberOfConsumers = operand.numberOfConsumers,
2880             .scale = operand.scale,
2881             .zeroPoint = operand.zeroPoint,
2882             .lifetime = operand.lifetime,
2883             .location = operand.location};
2884 }
2885 
convertToV1_2(const V1_3::Operand & operand)2886 V1_2::Operand convertToV1_2(const V1_3::Operand& operand) {
2887     return {.type = convertToV1_2(operand.type),
2888             .dimensions = operand.dimensions,
2889             .numberOfConsumers = operand.numberOfConsumers,
2890             .scale = operand.scale,
2891             .zeroPoint = operand.zeroPoint,
2892             .lifetime = static_cast<V1_0::OperandLifeTime>(operand.lifetime),
2893             .location = operand.location,
2894             .extraParams = operand.extraParams};
2895 }
2896 
convertToV1_3(const V1_0::Operand & operand)2897 V1_3::Operand convertToV1_3(const V1_0::Operand& operand) {
2898     return {.type = static_cast<V1_3::OperandType>(operand.type),
2899             .dimensions = operand.dimensions,
2900             .numberOfConsumers = operand.numberOfConsumers,
2901             .scale = operand.scale,
2902             .zeroPoint = operand.zeroPoint,
2903             .lifetime = convertToV1_3(operand.lifetime),
2904             .location = operand.location};
2905 }
2906 
convertToV1_3(const V1_2::Operand & operand)2907 V1_3::Operand convertToV1_3(const V1_2::Operand& operand) {
2908     return {.type = static_cast<V1_3::OperandType>(operand.type),
2909             .dimensions = operand.dimensions,
2910             .numberOfConsumers = operand.numberOfConsumers,
2911             .scale = operand.scale,
2912             .zeroPoint = operand.zeroPoint,
2913             .lifetime = convertToV1_3(operand.lifetime),
2914             .location = operand.location,
2915             .extraParams = operand.extraParams};
2916 }
2917 
// Identity overload: |operand| is already V1_3, so a copy of it is returned.
V1_3::Operand convertToV1_3(const V1_3::Operand& operand) {
    return operand;
}
2921 
// Identity overload: the operands are already V1_0, so a copy of |operands| is returned.
hidl_vec<V1_0::Operand> convertToV1_0(const hidl_vec<V1_0::Operand>& operands) {
    return operands;
}
2925 
// Narrows each V1_2 operand to V1_0 element by element, preserving order.
hidl_vec<V1_0::Operand> convertToV1_0(const hidl_vec<V1_2::Operand>& operands) {
    const size_t count = operands.size();
    hidl_vec<V1_0::Operand> converted(count);
    for (size_t i = 0; i < count; ++i) {
        converted[i] = convertToV1_0(operands[i]);
    }
    return converted;
}
2932 
// Narrows each V1_3 operand to V1_0 element by element, preserving order.
hidl_vec<V1_0::Operand> convertToV1_0(const hidl_vec<V1_3::Operand>& operands) {
    const size_t count = operands.size();
    hidl_vec<V1_0::Operand> converted(count);
    for (size_t i = 0; i < count; ++i) {
        converted[i] = convertToV1_0(operands[i]);
    }
    return converted;
}
2939 
// Widens each V1_0 operand to V1_2 element by element, preserving order.
hidl_vec<V1_2::Operand> convertToV1_2(const hidl_vec<V1_0::Operand>& operands) {
    const size_t count = operands.size();
    hidl_vec<V1_2::Operand> converted(count);
    for (size_t i = 0; i < count; ++i) {
        converted[i] = convertToV1_2(operands[i]);
    }
    return converted;
}
2946 
// Identity overload: the operands are already V1_2, so a copy of |operands| is returned.
hidl_vec<V1_2::Operand> convertToV1_2(const hidl_vec<V1_2::Operand>& operands) {
    return operands;
}
2950 
// Narrows each V1_3 operand to V1_2 element by element, preserving order.
hidl_vec<V1_2::Operand> convertToV1_2(const hidl_vec<V1_3::Operand>& operands) {
    const size_t count = operands.size();
    hidl_vec<V1_2::Operand> converted(count);
    for (size_t i = 0; i < count; ++i) {
        converted[i] = convertToV1_2(operands[i]);
    }
    return converted;
}
2957 
// Widens each V1_0 operand to V1_3 element by element, preserving order.
hidl_vec<V1_3::Operand> convertToV1_3(const hidl_vec<V1_0::Operand>& operands) {
    const size_t count = operands.size();
    hidl_vec<V1_3::Operand> converted(count);
    for (size_t i = 0; i < count; ++i) {
        converted[i] = convertToV1_3(operands[i]);
    }
    return converted;
}
2964 
// Widens each V1_2 operand to V1_3 element by element, preserving order.
hidl_vec<V1_3::Operand> convertToV1_3(const hidl_vec<V1_2::Operand>& operands) {
    const size_t count = operands.size();
    hidl_vec<V1_3::Operand> converted(count);
    for (size_t i = 0; i < count; ++i) {
        converted[i] = convertToV1_3(operands[i]);
    }
    return converted;
}
2971 
// Identity overload: the operands are already V1_3, so a copy of |operands| is returned.
hidl_vec<V1_3::Operand> convertToV1_3(const hidl_vec<V1_3::Operand>& operands) {
    return operands;
}
2975 
// Identity overload: |model| is already V1_0, so a copy of it is returned.
V1_0::Model convertToV1_0(const V1_0::Model& model) {
    return model;
}
2979 
convertToV1_0(const V1_1::Model & model)2980 V1_0::Model convertToV1_0(const V1_1::Model& model) {
2981     if (!compliantWithV1_0(model)) {
2982         LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
2983                    << " from V1_1::Model to V1_0::Model";
2984     }
2985     return {.operands = model.operands,
2986             .operations = uncheckedConvertToV1_0(model.operations),
2987             .inputIndexes = model.inputIndexes,
2988             .outputIndexes = model.outputIndexes,
2989             .operandValues = model.operandValues,
2990             .pools = model.pools};
2991 }
2992 
convertToV1_0(const V1_2::Model & model)2993 V1_0::Model convertToV1_0(const V1_2::Model& model) {
2994     if (!compliantWithV1_0(model)) {
2995         LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
2996                    << " from V1_2::Model to V1_0::Model";
2997     }
2998     return {.operands = convertToV1_0(model.operands),
2999             .operations = uncheckedConvertToV1_0(model.operations),
3000             .inputIndexes = model.inputIndexes,
3001             .outputIndexes = model.outputIndexes,
3002             .operandValues = model.operandValues,
3003             .pools = model.pools};
3004 }
3005 
convertToV1_0(const V1_3::Model & model)3006 V1_0::Model convertToV1_0(const V1_3::Model& model) {
3007     if (!compliantWithV1_0(model)) {
3008         LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
3009                    << " from V1_3::Model to V1_0::Model";
3010     }
3011     return {.operands = convertToV1_0(model.main.operands),
3012             .operations = uncheckedConvertToV1_0(model.main.operations),
3013             .inputIndexes = model.main.inputIndexes,
3014             .outputIndexes = model.main.outputIndexes,
3015             .operandValues = model.operandValues,
3016             .pools = model.pools};
3017 }
3018 
convertToV1_1(const V1_0::Model & model)3019 V1_1::Model convertToV1_1(const V1_0::Model& model) {
3020     return {.operands = model.operands,
3021             .operations = convertToV1_1(model.operations),
3022             .inputIndexes = model.inputIndexes,
3023             .outputIndexes = model.outputIndexes,
3024             .operandValues = model.operandValues,
3025             .pools = model.pools,
3026             .relaxComputationFloat32toFloat16 = false};
3027 }
3028 
// Identity overload: |model| is already V1_1, so a copy of it is returned.
V1_1::Model convertToV1_1(const V1_1::Model& model) {
    return model;
}
3032 
convertToV1_1(const V1_2::Model & model)3033 V1_1::Model convertToV1_1(const V1_2::Model& model) {
3034     if (!compliantWithV1_1(model)) {
3035         LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
3036                    << " from V1_2::Model to V1_1::Model";
3037     }
3038     return {.operands = convertToV1_0(model.operands),  // Operands in 1.1 and 1.0 are identical.
3039             .operations = uncheckedConvertToV1_1(model.operations),
3040             .inputIndexes = model.inputIndexes,
3041             .outputIndexes = model.outputIndexes,
3042             .operandValues = model.operandValues,
3043             .pools = model.pools,
3044             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16};
3045 }
3046 
convertToV1_1(const V1_3::Model & model)3047 V1_1::Model convertToV1_1(const V1_3::Model& model) {
3048     if (!compliantWithV1_1(model)) {
3049         LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
3050                    << " from V1_3::Model to V1_1::Model";
3051     }
3052     return {// Operands in 1.1 and 1.0 are identical.
3053             .operands = convertToV1_0(model.main.operands),
3054             .operations = uncheckedConvertToV1_1(model.main.operations),
3055             .inputIndexes = model.main.inputIndexes,
3056             .outputIndexes = model.main.outputIndexes,
3057             .operandValues = model.operandValues,
3058             .pools = model.pools,
3059             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16};
3060 }
3061 
convertToV1_2(const V1_0::Model & model)3062 V1_2::Model convertToV1_2(const V1_0::Model& model) {
3063     return {.operands = convertToV1_2(model.operands),
3064             .operations = convertToV1_2(model.operations),
3065             .inputIndexes = model.inputIndexes,
3066             .outputIndexes = model.outputIndexes,
3067             .operandValues = model.operandValues,
3068             .pools = model.pools,
3069             .relaxComputationFloat32toFloat16 = false};
3070 }
3071 
convertToV1_2(const V1_1::Model & model)3072 V1_2::Model convertToV1_2(const V1_1::Model& model) {
3073     return {.operands = convertToV1_2(model.operands),
3074             .operations = convertToV1_2(model.operations),
3075             .inputIndexes = model.inputIndexes,
3076             .outputIndexes = model.outputIndexes,
3077             .operandValues = model.operandValues,
3078             .pools = model.pools,
3079             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16};
3080 }
3081 
// Identity overload: |model| is already V1_2, so a copy of it is returned.
V1_2::Model convertToV1_2(const V1_2::Model& model) {
    return model;
}
3085 
convertToV1_2(const V1_3::Model & model)3086 V1_2::Model convertToV1_2(const V1_3::Model& model) {
3087     if (!compliantWithV1_2(model)) {
3088         LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
3089                    << " from V1_3::Model to V1_2::Model";
3090     }
3091     return {.operands = convertToV1_2(model.main.operands),
3092             .operations = uncheckedConvertToV1_2(model.main.operations),
3093             .inputIndexes = model.main.inputIndexes,
3094             .outputIndexes = model.main.outputIndexes,
3095             .operandValues = model.operandValues,
3096             .pools = model.pools,
3097             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
3098             .extensionNameToPrefix = model.extensionNameToPrefix};
3099 }
3100 
convertToV1_3(const V1_0::Model & model)3101 V1_3::Model convertToV1_3(const V1_0::Model& model) {
3102     return {.main = {.operands = convertToV1_3(model.operands),
3103                      .operations = convertToV1_3(model.operations),
3104                      .inputIndexes = model.inputIndexes,
3105                      .outputIndexes = model.outputIndexes},
3106             .operandValues = model.operandValues,
3107             .pools = model.pools,
3108             .relaxComputationFloat32toFloat16 = false};
3109 }
3110 
convertToV1_3(const V1_1::Model & model)3111 V1_3::Model convertToV1_3(const V1_1::Model& model) {
3112     return {.main = {.operands = convertToV1_3(model.operands),
3113                      .operations = convertToV1_3(model.operations),
3114                      .inputIndexes = model.inputIndexes,
3115                      .outputIndexes = model.outputIndexes},
3116             .operandValues = model.operandValues,
3117             .pools = model.pools,
3118             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16};
3119 }
3120 
convertToV1_3(const V1_2::Model & model)3121 V1_3::Model convertToV1_3(const V1_2::Model& model) {
3122     return {.main = {.operands = convertToV1_3(model.operands),
3123                      .operations = convertToV1_3(model.operations),
3124                      .inputIndexes = model.inputIndexes,
3125                      .outputIndexes = model.outputIndexes},
3126             .operandValues = model.operandValues,
3127             .pools = model.pools,
3128             .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
3129             .extensionNameToPrefix = model.extensionNameToPrefix};
3130 }
3131 
// Identity overload: |model| is already V1_3, so a copy of it is returned.
V1_3::Model convertToV1_3(const V1_3::Model& model) {
    return model;
}
3135 
compliantWithV1_0(const V1_0::Request & request)3136 bool compliantWithV1_0(const V1_0::Request& request) {
3137     return true;
3138 }
3139 
compliantWithV1_0(const V1_3::Request & request)3140 bool compliantWithV1_0(const V1_3::Request& request) {
3141     return std::all_of(request.pools.begin(), request.pools.end(), [](const auto& pool) {
3142         if (pool.getDiscriminator() != V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) {
3143             return false;
3144         }
3145         const auto& name = pool.hidlMemory().name();
3146         return name == "ashmem" || name == "mmap_fd";
3147     });
3148 }
3149 
compliantWithV1_2(const V1_3::Request & request)3150 bool compliantWithV1_2(const V1_3::Request& request) {
3151     return std::all_of(request.pools.begin(), request.pools.end(), [](const auto& pool) {
3152         if (pool.getDiscriminator() != V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) {
3153             return false;
3154         }
3155         const auto& name = pool.hidlMemory().name();
3156         return name == "ashmem" || name == "mmap_fd" || name == "hardware_buffer_blob" ||
3157                name == "hardware_buffer";
3158     });
3159 }
3160 
convertToV1_0(const V1_3::Request::MemoryPool & pool)3161 static hidl_memory convertToV1_0(const V1_3::Request::MemoryPool& pool) {
3162     switch (pool.getDiscriminator()) {
3163         case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory:
3164             return pool.hidlMemory();
3165         case V1_3::Request::MemoryPool::hidl_discriminator::token:
3166             return hidl_memory{};
3167     }
3168 }
3169 
convertToV1_3(const hidl_memory & pool)3170 static V1_3::Request::MemoryPool convertToV1_3(const hidl_memory& pool) {
3171     V1_3::Request::MemoryPool ret;
3172     ret.hidlMemory(pool);
3173     return ret;
3174 }
3175 
// Identity overload: |request| is already V1_0, so a copy of it is returned.
V1_0::Request convertToV1_0(const V1_0::Request& request) {
    return request;
}
3179 
uncheckedConvertToV1_0(const V1_3::Request & request)3180 static V1_0::Request uncheckedConvertToV1_0(const V1_3::Request& request) {
3181     hidl_vec<hidl_memory> pools(request.pools.size());
3182     std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
3183                    [](const auto& pool) { return convertToV1_0(pool); });
3184     return {.inputs = request.inputs, .outputs = request.outputs, .pools = std::move(pools)};
3185 }
3186 
convertToV1_0(const V1_3::Request & request)3187 V1_0::Request convertToV1_0(const V1_3::Request& request) {
3188     if (!compliantWithV1_0(request)) {
3189         LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
3190                    << " from V1_3::Request to V1_0::Request of version 1.0";
3191     }
3192     return uncheckedConvertToV1_0(request);
3193 }
3194 
convertToV1_2(const V1_3::Request & request)3195 V1_0::Request convertToV1_2(const V1_3::Request& request) {
3196     if (!compliantWithV1_2(request)) {
3197         LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
3198                    << " from V1_3::Request to V1_0::Request of version 1.2";
3199     }
3200     return uncheckedConvertToV1_0(request);
3201 }
3202 
convertToV1_3(const V1_0::Request & request)3203 V1_3::Request convertToV1_3(const V1_0::Request& request) {
3204     hidl_vec<V1_3::Request::MemoryPool> pools(request.pools.size());
3205     std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
3206                    [](const auto& pool) { return convertToV1_3(pool); });
3207     return {.inputs = request.inputs, .outputs = request.outputs, .pools = std::move(pools)};
3208 }
3209 
// Identity overload: |request| is already V1_3, so a copy of it is returned.
V1_3::Request convertToV1_3(const V1_3::Request& request) {
    return request;
}
3213 
syncWait(int fd,int timeout)3214 FenceState syncWait(int fd, int timeout) {
3215     // This implementation is directly based on the ::sync_wait() implementation.
3216 
3217     struct pollfd fds;
3218     int ret;
3219 
3220     if (fd < 0) {
3221         errno = EINVAL;
3222         return FenceState::UNKNOWN;
3223     }
3224 
3225     fds.fd = fd;
3226     fds.events = POLLIN;
3227 
3228     do {
3229         ret = poll(&fds, 1, timeout);
3230         if (ret > 0) {
3231             if (fds.revents & POLLNVAL) {
3232                 errno = EINVAL;
3233                 return FenceState::UNKNOWN;
3234             }
3235             if (fds.revents & POLLERR) {
3236                 errno = EINVAL;
3237                 return FenceState::ERROR;
3238             }
3239             return FenceState::SIGNALED;
3240         } else if (ret == 0) {
3241             errno = ETIME;
3242             return FenceState::ACTIVE;
3243         }
3244     } while (ret == -1 && (errno == EINTR || errno == EAGAIN));
3245 
3246     return FenceState::UNKNOWN;
3247 }
3248 
#ifdef NN_DEBUGGABLE
// Reads system property |str| as an unsigned 32-bit integer. Returns |defaultValue| when the
// property is unset, empty, not a valid unsigned integer, or larger than UINT32_MAX.
//
// This previously used std::stoi. That throws on a non-numeric value (e.g. "abc") and on
// values beyond int range that still fit in uint32_t (e.g. "4000000000"), terminating the
// process over a bad debug setprop. It also wrapped negative values to huge unsigned ones.
// GetUintProperty instead falls back to the default on any parse failure.
uint32_t getProp(const char* str, uint32_t defaultValue) {
    return android::base::GetUintProperty<uint32_t>(str, defaultValue);
}
#endif  // NN_DEBUGGABLE
3259 
3260 }  // namespace nn
3261 }  // namespace android
3262