1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Utils"
18
19 #include "Utils.h"
20
21 #include <android-base/logging.h>
22 #include <android-base/properties.h>
23 #include <android-base/strings.h>
24 #include <errno.h>
25 #include <poll.h>
26
27 #include <algorithm>
28 #include <cfloat>
29 #include <functional>
30 #include <iostream>
31 #include <limits>
32 #include <numeric>
33 #include <set>
34 #include <string>
35 #include <tuple>
36 #include <unordered_map>
37 #include <utility>
38 #include <vector>
39
40 #include "ControlFlow.h"
41 #include "NeuralNetworks.h"
42 #include "NeuralNetworksOEM.h"
43 #include "OperationResolver.h"
44 #include "ValidateHal.h"
45
46 namespace android {
47 namespace nn {
48
49 using namespace hal;
50
51 constexpr PerformanceInfo kNoPerformanceInfo = {.execTime = FLT_MAX, .powerUsage = FLT_MAX};
52
53 const char kVLogPropKey[] = "debug.nn.vlog";
54 int vLogMask = ~0;
55
56 // Split the space separated list of tags from verbose log setting and build the
57 // logging mask from it. note that '1' and 'all' are special cases to enable all
58 // verbose logging.
59 //
60 // NN API verbose logging setting comes from system property debug.nn.vlog.
61 // Example:
62 // setprop debug.nn.vlog 1 : enable all logging tags.
63 // setprop debug.nn.vlog "model compilation" : only enable logging for MODEL and
64 // COMPILATION tags.
initVLogMask()65 void initVLogMask() {
66 vLogMask = 0;
67 const std::string vLogSetting = android::base::GetProperty(kVLogPropKey, "");
68 if (vLogSetting.empty()) {
69 return;
70 }
71
72 std::unordered_map<std::string, int> vLogFlags = {{"1", -1},
73 {"all", -1},
74 {"model", MODEL},
75 {"compilation", COMPILATION},
76 {"execution", EXECUTION},
77 {"cpuexe", CPUEXE},
78 {"manager", MANAGER},
79 {"driver", DRIVER},
80 {"memory", MEMORY}};
81
82 std::vector<std::string> elements = android::base::Split(vLogSetting, " ,:");
83 for (const auto& elem : elements) {
84 const auto& flag = vLogFlags.find(elem);
85 if (flag == vLogFlags.end()) {
86 LOG(ERROR) << "Unknown trace flag: " << elem;
87 continue;
88 }
89
90 if (flag->second == -1) {
91 // -1 is used for the special values "1" and "all" that enable all
92 // tracing.
93 vLogMask = ~0;
94 return;
95 } else {
96 vLogMask |= 1 << flag->second;
97 }
98 }
99 }
100
makeDeadline(uint64_t duration)101 Deadline makeDeadline(uint64_t duration) {
102 const auto maxTime = Deadline::max();
103 const auto currentTime = std::chrono::steady_clock::now();
104
105 // Create Deadline. If there would be an overflow, use the max value.
106 const uint64_t remainingNanoseconds =
107 std::chrono::duration_cast<std::chrono::nanoseconds>(maxTime - currentTime).count();
108 if (duration > remainingNanoseconds) {
109 return maxTime;
110 }
111 return currentTime + std::chrono::nanoseconds{duration};
112 }
113
makeDeadline(std::optional<uint64_t> duration)114 std::optional<Deadline> makeDeadline(std::optional<uint64_t> duration) {
115 return duration.has_value() ? makeDeadline(*duration) : std::optional<Deadline>{};
116 }
117
// Tick count of the largest nanosecond-resolution steady_clock time point,
// i.e. the largest representable "nanoseconds since epoch" value.
static uint64_t getMaxNanosecondsSinceEpoch() {
    using NanosecondTimePoint =
            std::chrono::time_point<std::chrono::steady_clock, std::chrono::nanoseconds>;
    const auto ticks = NanosecondTimePoint::max().time_since_epoch().count();
    return static_cast<uint64_t>(ticks);
}
123
makeDeadline(const OptionalTimePoint & timePoint)124 std::optional<Deadline> makeDeadline(const OptionalTimePoint& timePoint) {
125 using Discriminator = hal::OptionalTimePoint::hidl_discriminator;
126 if (timePoint.getDiscriminator() == Discriminator::none) {
127 return std::nullopt;
128 }
129 const uint64_t nanosecondsSinceEpoch = timePoint.nanosecondsSinceEpoch();
130 const uint64_t maxNanosecondsSinceEpoch = getMaxNanosecondsSinceEpoch();
131
132 // Clamp time point to max.
133 if (nanosecondsSinceEpoch >= maxNanosecondsSinceEpoch) {
134 return Deadline::max();
135 }
136
137 // Return provided time point.
138 return Deadline{std::chrono::nanoseconds{nanosecondsSinceEpoch}};
139 }
140
hasDeadlinePassed(const std::optional<Deadline> & deadline)141 bool hasDeadlinePassed(const std::optional<Deadline>& deadline) {
142 if (!deadline.has_value()) {
143 return false;
144 }
145 return std::chrono::steady_clock::now() >= *deadline;
146 }
147
makeTimePoint(const Deadline & deadline)148 static OptionalTimePoint makeTimePoint(const Deadline& deadline) {
149 const auto timeSinceEpoch = deadline.time_since_epoch();
150 const uint64_t nanosecondsSinceEpoch =
151 std::chrono::duration_cast<std::chrono::nanoseconds>(timeSinceEpoch).count();
152 OptionalTimePoint ret;
153 ret.nanosecondsSinceEpoch(nanosecondsSinceEpoch);
154 return ret;
155 }
156
makeTimePoint(const std::optional<Deadline> & deadline)157 OptionalTimePoint makeTimePoint(const std::optional<Deadline>& deadline) {
158 return deadline.has_value() ? makeTimePoint(*deadline) : OptionalTimePoint{};
159 }
160
isExtensionOperandType(int32_t type)161 static bool isExtensionOperandType(int32_t type) {
162 return static_cast<uint32_t>(type) > static_cast<uint32_t>(OperandTypeRange::BASE_MAX);
163 }
164
isExtensionOperationType(ANeuralNetworksOperationType type)165 static bool isExtensionOperationType(ANeuralNetworksOperationType type) {
166 return static_cast<uint32_t>(type) > static_cast<uint32_t>(OperationTypeRange::BASE_MAX);
167 }
168
isExtensionOperandType(OperandType type)169 bool isExtensionOperandType(OperandType type) {
170 return isExtensionOperandType(static_cast<int32_t>(type));
171 }
172
isExtensionOperationType(OperationType type)173 bool isExtensionOperationType(OperationType type) {
174 return isExtensionOperationType(static_cast<int32_t>(type));
175 }
176
177 namespace {
178
179 template <typename EntryType, uint32_t entryCount, uint32_t entryCountOEM>
tableLookup(const EntryType (& table)[entryCount],const EntryType (& tableOEM)[entryCountOEM],uint32_t code)180 EntryType tableLookup(const EntryType (&table)[entryCount],
181 const EntryType (&tableOEM)[entryCountOEM], uint32_t code) {
182 if (code < entryCount) {
183 return table[code];
184 } else if (code >= kOEMCodeBase && (code - kOEMCodeBase) < entryCountOEM) {
185 return tableOEM[code - kOEMCodeBase];
186 } else {
187 nnAssert(!"tableLookup: bad code");
188 return EntryType();
189 }
190 }
191
192 class OperationValidationContext : public IOperationValidationContext {
193 DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext);
194
195 public:
OperationValidationContext(const char * operationName,uint32_t inputCount,const uint32_t * inputIndexes,uint32_t outputCount,const uint32_t * outputIndexes,const Operand * operands,HalVersion halVersion)196 OperationValidationContext(const char* operationName, uint32_t inputCount,
197 const uint32_t* inputIndexes, uint32_t outputCount,
198 const uint32_t* outputIndexes, const Operand* operands,
199 HalVersion halVersion)
200 : operationName(operationName),
201 inputCount(inputCount),
202 inputIndexes(inputIndexes),
203 outputCount(outputCount),
204 outputIndexes(outputIndexes),
205 operands(operands),
206 halVersion(halVersion) {}
207
208 const char* getOperationName() const override;
209 HalVersion getHalVersion() const override;
210
211 uint32_t getNumInputs() const override;
212 OperandType getInputType(uint32_t index) const override;
213 Shape getInputShape(uint32_t index) const override;
214 const OperandExtraParams getInputExtraParams(uint32_t index) const override;
215
216 uint32_t getNumOutputs() const override;
217 OperandType getOutputType(uint32_t index) const override;
218 Shape getOutputShape(uint32_t index) const override;
219
220 private:
221 const Operand* getInputOperand(uint32_t index) const;
222 const Operand* getOutputOperand(uint32_t index) const;
223
224 const char* operationName;
225 uint32_t inputCount;
226 const uint32_t* inputIndexes;
227 uint32_t outputCount;
228 const uint32_t* outputIndexes;
229 const Operand* operands;
230 HalVersion halVersion;
231 };
232
getOperationName() const233 const char* OperationValidationContext::getOperationName() const {
234 return operationName;
235 }
236
getHalVersion() const237 HalVersion OperationValidationContext::getHalVersion() const {
238 return halVersion;
239 }
240
getInputOperand(uint32_t index) const241 const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
242 CHECK(index < static_cast<uint32_t>(inputCount));
243 return &operands[inputIndexes[index]];
244 }
245
getOutputOperand(uint32_t index) const246 const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const {
247 CHECK(index < static_cast<uint32_t>(outputCount));
248 return &operands[outputIndexes[index]];
249 }
250
getNumInputs() const251 uint32_t OperationValidationContext::getNumInputs() const {
252 return inputCount;
253 }
254
getNumOutputs() const255 uint32_t OperationValidationContext::getNumOutputs() const {
256 return outputCount;
257 }
258
getInputType(uint32_t index) const259 OperandType OperationValidationContext::getInputType(uint32_t index) const {
260 return getInputOperand(index)->type;
261 }
262
getInputShape(uint32_t index) const263 Shape OperationValidationContext::getInputShape(uint32_t index) const {
264 const Operand* operand = getInputOperand(index);
265 return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
266 operand->extraParams};
267 }
268
getInputExtraParams(uint32_t index) const269 const OperandExtraParams OperationValidationContext::getInputExtraParams(uint32_t index) const {
270 return getInputOperand(index)->extraParams;
271 }
272
getOutputType(uint32_t index) const273 OperandType OperationValidationContext::getOutputType(uint32_t index) const {
274 return getOutputOperand(index)->type;
275 }
276
getOutputShape(uint32_t index) const277 Shape OperationValidationContext::getOutputShape(uint32_t index) const {
278 const Operand* operand = getOutputOperand(index);
279 return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
280 operand->extraParams};
281 }
282
283 }; // anonymous namespace
284
// Number of elements in the built-in array X, usable in constant expressions.
// std::size() refuses to compile when X is a pointer, whereas the classic
// sizeof(X) / sizeof(X[0]) silently returns a meaningless value.
#define COUNT(X) (std::size(X))
286
getOperandTypeName(OperandType type)287 std::string getOperandTypeName(OperandType type) {
288 return toString(type);
289 }
290
getOperationName(uint32_t code)291 static std::string getOperationName(uint32_t code) {
292 return getOperationName(static_cast<OperationType>(code));
293 }
294
getOperationName(OperationType type)295 std::string getOperationName(OperationType type) {
296 return toString(type);
297 }
298
299 const uint32_t kSizeOfDataType[]{
300 4, // ANEURALNETWORKS_FLOAT32
301 4, // ANEURALNETWORKS_INT32
302 4, // ANEURALNETWORKS_UINT32
303 4, // ANEURALNETWORKS_TENSOR_FLOAT32
304 4, // ANEURALNETWORKS_TENSOR_INT32
305 1, // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
306 1, // ANEURALNETWORKS_BOOL
307 2, // ANEURALNETWORKS_TENSOR_QUANT16_SYMM
308 2, // ANEURALNETWORKS_TENSOR_FLOAT16
309 1, // ANEURALNETWORKS_TENSOR_BOOL8
310 2, // ANEURALNETWORKS_FLOAT16
311 1, // ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL
312 2, // ANEURALNETWORKS_TENSOR_QUANT16_ASYMM
313 1, // ANEURALNETWORKS_TENSOR_QUANT8_SYMM
314 1, // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED
315 0, // ANEURALNETWORKS_MODEL
316 };
317
318 static_assert(COUNT(kSizeOfDataType) == kNumberOfDataTypes, "kSizeOfDataType is incorrect");
319
320 const bool kScalarDataType[]{
321 true, // ANEURALNETWORKS_FLOAT32
322 true, // ANEURALNETWORKS_INT32
323 true, // ANEURALNETWORKS_UINT32
324 false, // ANEURALNETWORKS_TENSOR_FLOAT32
325 false, // ANEURALNETWORKS_TENSOR_INT32
326 false, // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM
327 true, // ANEURALNETWORKS_BOOL
328 false, // ANEURALNETWORKS_TENSOR_QUANT16_SYMM
329 false, // ANEURALNETWORKS_TENSOR_FLOAT16
330 false, // ANEURALNETWORKS_TENSOR_BOOL8
331 true, // ANEURALNETWORKS_FLOAT16
332 false, // ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL
333 false, // ANEURALNETWORKS_TENSOR_QUANT16_ASYMM
334 false, // ANEURALNETWORKS_TENSOR_QUANT8_SYMM
335 false, // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED
336 true, // ANEURALNETWORKS_MODEL
337 };
338
339 static_assert(COUNT(kScalarDataType) == kNumberOfDataTypes, "kScalarDataType is incorrect");
340
341 const uint32_t kSizeOfDataTypeOEM[]{
342 0, // ANEURALNETWORKS_OEM
343 1, // ANEURALNETWORKS_TENSOR_OEM_BYTE
344 };
345
346 static_assert(COUNT(kSizeOfDataTypeOEM) == kNumberOfDataTypesOEM,
347 "kSizeOfDataTypeOEM is incorrect");
348
349 const bool kScalarDataTypeOEM[]{
350 true, // ANEURALNETWORKS_OEM
351 false, // ANEURALNETWORKS_TENSOR_OEM_BYTE
352 };
353
354 static_assert(COUNT(kScalarDataTypeOEM) == kNumberOfDataTypesOEM,
355 "kScalarDataTypeOEM is incorrect");
356
nonExtensionOperandTypeIsScalar(int type)357 bool nonExtensionOperandTypeIsScalar(int type) {
358 CHECK(!isExtensionOperandType(type)) << "Extension operand types are not supported";
359 return tableLookup(kScalarDataType, kScalarDataTypeOEM, type);
360 }
361
nonExtensionOperandSizeOfData(OperandType type,const std::vector<uint32_t> & dimensions)362 uint32_t nonExtensionOperandSizeOfData(OperandType type, const std::vector<uint32_t>& dimensions) {
363 CHECK(!isExtensionOperandType(type)) << "Size of extension operand data is unknown";
364 int n = static_cast<int>(type);
365 uint32_t sizeOfElement = tableLookup(kSizeOfDataType, kSizeOfDataTypeOEM, n);
366 return tableLookup(kScalarDataType, kScalarDataTypeOEM, n)
367 ? sizeOfElement
368 : sizeOfTensorData(sizeOfElement, dimensions);
369 }
370
// Multiplies sizeOfElement by every dimension, tracking uint32_t overflow.
// Returns {false, size} on success and {true, 0} as soon as the running product
// exceeds uint32_t max. An empty dimension list yields {false, 0}.
static std::pair<bool, uint32_t> sizeOfTensorDataHelper(uint32_t sizeOfElement,
                                                        const std::vector<uint32_t>& dimensions) {
    if (dimensions.empty()) return {false, 0};

    constexpr uint64_t kLimit = std::numeric_limits<uint32_t>::max();
    // Both factors are at most 2^32 - 1, so the 64-bit product cannot wrap.
    uint64_t total = sizeOfElement;
    for (size_t axis = 0; axis < dimensions.size(); ++axis) {
        total *= dimensions[axis];
        if (total > kLimit) {
            return {true, 0};
        }
    }
    return {false, static_cast<uint32_t>(total)};
}
385
sizeOfTensorData(uint32_t sizeOfElement,const std::vector<uint32_t> & dimensions)386 uint32_t sizeOfTensorData(uint32_t sizeOfElement, const std::vector<uint32_t>& dimensions) {
387 const auto [overflow, size] = sizeOfTensorDataHelper(sizeOfElement, dimensions);
388 CHECK(!overflow);
389 return size;
390 }
391
nonExtensionOperandSizeOfDataOverflowsUInt32(hal::OperandType type,const std::vector<uint32_t> & dimensions)392 bool nonExtensionOperandSizeOfDataOverflowsUInt32(hal::OperandType type,
393 const std::vector<uint32_t>& dimensions) {
394 CHECK(!isExtensionOperandType(type)) << "Size of extension operand data is unknown";
395 int n = static_cast<int>(type);
396 uint32_t sizeOfElement = tableLookup(kSizeOfDataType, kSizeOfDataTypeOEM, n);
397 return tableLookup(kScalarDataType, kScalarDataTypeOEM, n)
398 ? false
399 : sizeOfTensorDataOverflowsUInt32(sizeOfElement, dimensions);
400 }
401
sizeOfTensorDataOverflowsUInt32(uint32_t sizeOfElement,const std::vector<uint32_t> & dimensions)402 bool sizeOfTensorDataOverflowsUInt32(uint32_t sizeOfElement,
403 const std::vector<uint32_t>& dimensions) {
404 return sizeOfTensorDataHelper(sizeOfElement, dimensions).first;
405 }
406
tensorHasUnspecifiedDimensions(int type,const uint32_t * dim,uint32_t dimCount)407 bool tensorHasUnspecifiedDimensions(int type, const uint32_t* dim, uint32_t dimCount) {
408 if (!isExtensionOperandType(type)) {
409 CHECK(!nonExtensionOperandTypeIsScalar(type))
410 << "A scalar type can never have unspecified dimensions";
411 }
412 return dimCount == 0 || std::find(dim, dim + dimCount, 0) != (dim + dimCount);
413 }
414
tensorHasUnspecifiedDimensions(OperandType type,const std::vector<uint32_t> & dimensions)415 bool tensorHasUnspecifiedDimensions(OperandType type, const std::vector<uint32_t>& dimensions) {
416 return tensorHasUnspecifiedDimensions(static_cast<int>(type), dimensions.data(),
417 dimensions.size());
418 }
419
tensorHasUnspecifiedDimensions(const ANeuralNetworksOperandType * type)420 bool tensorHasUnspecifiedDimensions(const ANeuralNetworksOperandType* type) {
421 return tensorHasUnspecifiedDimensions(type->type, type->dimensions, type->dimensionCount);
422 }
423
tensorHasUnspecifiedDimensions(const Operand & operand)424 bool tensorHasUnspecifiedDimensions(const Operand& operand) {
425 return tensorHasUnspecifiedDimensions(static_cast<int>(operand.type), operand.dimensions.data(),
426 operand.dimensions.size());
427 }
428
// Number of bytes to skip from offset `index` so that a value of `length`
// bytes starts on a suitable boundary:
// - length below 2: no alignment;
// - length below 4: 2-byte alignment;
// - otherwise: 4-byte alignment.
uint32_t alignBytesNeeded(uint32_t index, size_t length) {
    uint32_t mask = 3;  // 4-byte boundary
    if (length < 2) {
        mask = 0;  // no alignment necessary
    } else if (length < 4) {
        mask = 1;  // 2-byte boundary
    }
    // In uint32_t arithmetic -index == ~(index - 1). (-index) & mask is the
    // distance from index up to the next multiple of (mask + 1).
    return (0u - index) & mask;
}
441
logModelToInfo(const V1_0::Model & model)442 void logModelToInfo(const V1_0::Model& model) {
443 LOG(INFO) << "V1_0::Model start";
444 LOG(INFO) << "operands" << toString(model.operands);
445 LOG(INFO) << "operations" << toString(model.operations);
446 LOG(INFO) << "inputIndexes" << toString(model.inputIndexes);
447 LOG(INFO) << "outputIndexes" << toString(model.outputIndexes);
448 LOG(INFO) << "operandValues size" << model.operandValues.size();
449 LOG(INFO) << "pools" << SHOW_IF_DEBUG(toString(model.pools));
450 }
451
logModelToInfo(const V1_1::Model & model)452 void logModelToInfo(const V1_1::Model& model) {
453 LOG(INFO) << "V1_1::Model start";
454 LOG(INFO) << "operands" << toString(model.operands);
455 LOG(INFO) << "operations" << toString(model.operations);
456 LOG(INFO) << "inputIndexes" << toString(model.inputIndexes);
457 LOG(INFO) << "outputIndexes" << toString(model.outputIndexes);
458 LOG(INFO) << "operandValues size " << model.operandValues.size();
459 LOG(INFO) << "pools" << SHOW_IF_DEBUG(toString(model.pools));
460 }
461
logModelToInfo(const V1_2::Model & model)462 void logModelToInfo(const V1_2::Model& model) {
463 LOG(INFO) << "V1_2::Model start";
464 LOG(INFO) << "operands" << toString(model.operands);
465 LOG(INFO) << "operations" << toString(model.operations);
466 LOG(INFO) << "inputIndexes" << toString(model.inputIndexes);
467 LOG(INFO) << "outputIndexes" << toString(model.outputIndexes);
468 LOG(INFO) << "operandValues size" << model.operandValues.size();
469 LOG(INFO) << "pools" << SHOW_IF_DEBUG(toString(model.pools));
470 LOG(INFO) << "relaxComputationFloat32toFloat16" << model.relaxComputationFloat32toFloat16;
471 LOG(INFO) << "extensionNameToPrefix" << toString(model.extensionNameToPrefix);
472 }
473
logSubgraphToInfo(std::string label,const V1_3::Subgraph & subgraph)474 static void logSubgraphToInfo(std::string label, const V1_3::Subgraph& subgraph) {
475 LOG(INFO) << label << ".operands" << toString(subgraph.operands);
476 LOG(INFO) << label << ".operations" << toString(subgraph.operations);
477 LOG(INFO) << label << ".inputIndexes" << toString(subgraph.inputIndexes);
478 LOG(INFO) << label << ".outputIndexes" << toString(subgraph.outputIndexes);
479 }
480
logModelToInfo(const V1_3::Model & model)481 void logModelToInfo(const V1_3::Model& model) {
482 LOG(INFO) << "V1_3::Model start";
483 logSubgraphToInfo("main", model.main);
484 for (uint32_t i = 0, n = model.referenced.size(); i < n; ++i) {
485 logSubgraphToInfo("referenced[" + std::to_string(i) + "]", model.referenced[i]);
486 }
487 LOG(INFO) << "operandValues size " << model.operandValues.size();
488 LOG(INFO) << "pools" << SHOW_IF_DEBUG(toString(model.pools));
489 LOG(INFO) << "relaxComputationFloat32toFloat16 " << model.relaxComputationFloat32toFloat16;
490 LOG(INFO) << "extensionNameToPrefix" << toString(model.extensionNameToPrefix);
491 }
492
validateOperandSymmPerChannelQuantParams(const Operand & halOperand,const ANeuralNetworksSymmPerChannelQuantParams & channelQuant,const char * tag)493 bool validateOperandSymmPerChannelQuantParams(
494 const Operand& halOperand, const ANeuralNetworksSymmPerChannelQuantParams& channelQuant,
495 const char* tag) {
496 if (halOperand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
497 return false;
498 }
499
500 NN_RET_CHECK_LT(channelQuant.channelDim, halOperand.dimensions.size()) << tag;
501 NN_RET_CHECK(channelQuant.scales != nullptr) << tag;
502 NN_RET_CHECK_EQ(channelQuant.scaleCount, halOperand.dimensions[channelQuant.channelDim]) << tag;
503 NN_RET_CHECK_NE(halOperand.dimensions[channelQuant.channelDim], 0u)
504 << tag << " channel dimension " << channelQuant.channelDim << " is underspecified";
505 for (uint32_t i = 0; i < halOperand.dimensions[channelQuant.channelDim]; i++) {
506 NN_RET_CHECK_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]";
507 }
508 return true;
509 }
510
validateScalarDimensions(const ANeuralNetworksOperandType & type,const char * tag)511 static bool validateScalarDimensions(const ANeuralNetworksOperandType& type, const char* tag) {
512 NN_RET_CHECK_EQ(type.dimensionCount, 0u) << tag << " invalid dimensions for scalar type";
513 NN_RET_CHECK(type.dimensions == nullptr) << tag << " invalid dimensions for scalar type";
514 return true;
515 }
516
validateQuant8AsymmParams(const ANeuralNetworksOperandType & type,const char * tag)517 static bool validateQuant8AsymmParams(const ANeuralNetworksOperandType& type, const char* tag) {
518 NN_RET_CHECK(0 <= type.zeroPoint && type.zeroPoint <= 255)
519 << tag << " invalid zeroPoint: " << type.zeroPoint;
520 NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
521 return true;
522 }
523
validateQuant8AsymmSignedParams(const ANeuralNetworksOperandType & type,const char * tag)524 static bool validateQuant8AsymmSignedParams(const ANeuralNetworksOperandType& type,
525 const char* tag) {
526 NN_RET_CHECK(-128 <= type.zeroPoint && type.zeroPoint <= 127)
527 << tag << " invalid zeroPoint: " << type.zeroPoint;
528 NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
529 return true;
530 }
531
validateQuant8SymmParams(const ANeuralNetworksOperandType & type,const char * tag)532 static bool validateQuant8SymmParams(const ANeuralNetworksOperandType& type, const char* tag) {
533 NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint;
534 NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
535 return true;
536 }
537
validateQuant16AsymmParams(const ANeuralNetworksOperandType & type,const char * tag)538 static bool validateQuant16AsymmParams(const ANeuralNetworksOperandType& type, const char* tag) {
539 NN_RET_CHECK(0 <= type.zeroPoint && type.zeroPoint <= 65535)
540 << tag << " invalid zeroPoint: " << type.zeroPoint;
541 NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
542 return true;
543 }
544
validateQuantSymmParams(const ANeuralNetworksOperandType & type,const char * tag)545 static bool validateQuantSymmParams(const ANeuralNetworksOperandType& type, const char* tag) {
546 NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
547 NN_RET_CHECK_GT(type.scale, 0.f) << tag << " invalid scale";
548 return true;
549 }
550
validateNoQuantParams(const ANeuralNetworksOperandType & type,const char * tag)551 static bool validateNoQuantParams(const ANeuralNetworksOperandType& type, const char* tag) {
552 NN_RET_CHECK_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
553 NN_RET_CHECK_EQ(type.scale, 0.f) << tag << " scale is not zero";
554 return true;
555 }
556
validateTensorDimensions(const ANeuralNetworksOperandType & type,const Extension::OperandTypeInformation * const extensionOperandTypeInfo,const char * tag,bool allowPartial)557 static bool validateTensorDimensions(
558 const ANeuralNetworksOperandType& type,
559 const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag,
560 bool allowPartial) {
561 if (!allowPartial) {
562 NN_RET_CHECK_GT(type.dimensionCount, 0u) << tag << " invalid operand dimensions";
563 }
564 uint64_t size =
565 isExtensionOperandType(type.type)
566 ? extensionOperandTypeInfo->byteSize
567 : tableLookup(kSizeOfDataType, kSizeOfDataTypeOEM, static_cast<int>(type.type));
568 constexpr uint64_t kMaxSize = std::numeric_limits<uint32_t>::max();
569 for (uint32_t i = 0; i < type.dimensionCount; i++) {
570 if (!allowPartial) {
571 NN_RET_CHECK_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions";
572 }
573 if (type.dimensions[i] != 0) {
574 size *= type.dimensions[i];
575 NN_RET_CHECK_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize;
576 }
577 }
578 return true;
579 }
580
validateOperandTypeHelper(const ANeuralNetworksOperandType & type,const Extension::OperandTypeInformation * const extensionOperandTypeInfo,const char * tag,bool allowPartial)581 static bool validateOperandTypeHelper(
582 const ANeuralNetworksOperandType& type,
583 const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag,
584 bool allowPartial) {
585 NN_RET_CHECK_EQ(type.dimensionCount == 0, type.dimensions == nullptr);
586 if (isExtensionOperandType(type.type)) {
587 NN_RET_CHECK(extensionOperandTypeInfo != nullptr);
588 if (extensionOperandTypeInfo->isTensor) {
589 NN_RET_CHECK(
590 validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
591 } else {
592 NN_RET_CHECK(validateScalarDimensions(type, tag));
593 }
594 return validateNoQuantParams(type, tag);
595 }
596
597 NN_RET_CHECK(extensionOperandTypeInfo == nullptr);
598 NN_RET_CHECK(validCode(kNumberOfDataTypes, kNumberOfDataTypesOEM, type.type))
599 << tag << " invalid OperandType: " << type.type;
600
601 bool isScalar = tableLookup(kScalarDataType, kScalarDataTypeOEM, type.type);
602 if (isScalar) {
603 NN_RET_CHECK(validateScalarDimensions(type, tag));
604 if (type.type != ANEURALNETWORKS_OEM_SCALAR) { // Historically, we have allowed OEM types
605 // to use quantization parameters.
606 NN_RET_CHECK(validateNoQuantParams(type, tag));
607 }
608 } else {
609 NN_RET_CHECK(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
610 if (type.type == ANEURALNETWORKS_TENSOR_QUANT8_ASYMM) {
611 NN_RET_CHECK(validateQuant8AsymmParams(type, tag));
612 } else if (type.type == ANEURALNETWORKS_TENSOR_QUANT8_ASYMM_SIGNED) {
613 NN_RET_CHECK(validateQuant8AsymmSignedParams(type, tag));
614 } else if (type.type == ANEURALNETWORKS_TENSOR_QUANT8_SYMM) {
615 NN_RET_CHECK(validateQuant8SymmParams(type, tag));
616 } else if (type.type == ANEURALNETWORKS_TENSOR_QUANT16_ASYMM) {
617 NN_RET_CHECK(validateQuant16AsymmParams(type, tag));
618 } else if (type.type == ANEURALNETWORKS_TENSOR_QUANT16_SYMM) {
619 NN_RET_CHECK(validateQuantSymmParams(type, tag));
620 } else if (type.type == ANEURALNETWORKS_TENSOR_INT32) {
621 // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters.
622 } else if (type.type == ANEURALNETWORKS_TENSOR_OEM_BYTE) {
623 // Historically, we have allowed OEM types to use quantization parameters.
624 } else {
625 NN_RET_CHECK(validateNoQuantParams(type, tag));
626 }
627 }
628
629 return true;
630 }
631
validateOperandType(const ANeuralNetworksOperandType & type,const Extension::OperandTypeInformation * const extensionOperandTypeInfo,const char * tag,bool allowPartial)632 int validateOperandType(const ANeuralNetworksOperandType& type,
633 const Extension::OperandTypeInformation* const extensionOperandTypeInfo,
634 const char* tag, bool allowPartial) {
635 return validateOperandTypeHelper(type, extensionOperandTypeInfo, tag, allowPartial)
636 ? ANEURALNETWORKS_NO_ERROR
637 : ANEURALNETWORKS_BAD_DATA;
638 }
639
validateOperandList(uint32_t count,const uint32_t * list,uint32_t operandCount,const char * tag)640 int validateOperandList(uint32_t count, const uint32_t* list, uint32_t operandCount,
641 const char* tag) {
642 for (uint32_t i = 0; i < count; i++) {
643 if (list[i] >= operandCount) {
644 LOG(ERROR) << tag << " invalid operand index at " << i << " = " << list[i]
645 << ", operandCount " << operandCount;
646 return ANEURALNETWORKS_BAD_DATA;
647 }
648 }
649 return ANEURALNETWORKS_NO_ERROR;
650 }
651
validateOperationOperandTypes(const std::vector<Operand> & operands,uint32_t inOperandCount,const uint32_t * inOperandIndexes,const std::vector<OperandType> & inExpectedTypes,uint32_t outOperandCount,const uint32_t * outOperandIndexes,const std::vector<OperandType> & outExpectedInTypes)652 int validateOperationOperandTypes(const std::vector<Operand>& operands, uint32_t inOperandCount,
653 const uint32_t* inOperandIndexes,
654 const std::vector<OperandType>& inExpectedTypes,
655 uint32_t outOperandCount, const uint32_t* outOperandIndexes,
656 const std::vector<OperandType>& outExpectedInTypes) {
657 if (inOperandCount != static_cast<uint32_t>(inExpectedTypes.size()) ||
658 outOperandCount != static_cast<uint32_t>(outExpectedInTypes.size())) {
659 LOG(ERROR) << "Wrong operand count: expected " << inExpectedTypes.size() << " inputs and "
660 << outExpectedInTypes.size() << " outputs,"
661 << "got " << inOperandCount << " inputs and " << outOperandCount << " outputs";
662 return ANEURALNETWORKS_BAD_DATA;
663 }
664 for (uint32_t i = 0; i < inOperandCount; i++) {
665 if (operands[inOperandIndexes[i]].type != inExpectedTypes[i]) {
666 LOG(ERROR) << "Invalid input tensor type "
667 << toString(operands[inOperandIndexes[i]].type) << " for input " << i
668 << ", expected " << toString(inExpectedTypes[i]);
669 return ANEURALNETWORKS_BAD_DATA;
670 }
671 }
672 for (uint32_t i = 0; i < outOperandCount; i++) {
673 if (operands[outOperandIndexes[i]].type != outExpectedInTypes[i]) {
674 LOG(ERROR) << "Invalid output tensor type "
675 << toString(operands[outOperandIndexes[i]].type) << " for input " << i
676 << ", expected " << toString(outExpectedInTypes[i]);
677 return ANEURALNETWORKS_BAD_DATA;
678 }
679 }
680
681 return ANEURALNETWORKS_NO_ERROR;
682 }
683
validateHalVersion(ANeuralNetworksOperationType opType,HalVersion halVersion,HalVersion minSupportedHalVersion)684 static int validateHalVersion(ANeuralNetworksOperationType opType, HalVersion halVersion,
685 HalVersion minSupportedHalVersion) {
686 if (halVersion < minSupportedHalVersion) {
687 LOG(ERROR) << "The given inputs and outputs for operation " << getOperationName(opType)
688 << " are only supported in " << toString(minSupportedHalVersion)
689 << " and later (validating using " << toString(halVersion) << ")";
690 return ANEURALNETWORKS_BAD_DATA;
691 }
692 return ANEURALNETWORKS_NO_ERROR;
693 }
694
695 // Checks if two operands have the same types, ranks (if specified), dimensions
696 // (if specified), scales, zeroPoints, and extraParams.
compatible(const Operand & a,const Operand & b)697 static bool compatible(const Operand& a, const Operand& b) {
698 NN_RET_CHECK(a.type == b.type) << toString(a.type) << " != " << toString(b.type);
699 if (a.dimensions.size() != 0 && b.dimensions.size() != 0) {
700 NN_RET_CHECK_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions";
701 for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) {
702 if (a.dimensions[i] != 0 && b.dimensions[i] != 0) {
703 NN_RET_CHECK_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions";
704 }
705 }
706 }
707 NN_RET_CHECK_EQ(a.scale, b.scale);
708 NN_RET_CHECK_EQ(a.zeroPoint, b.zeroPoint);
709 NN_RET_CHECK(a.extraParams == b.extraParams)
710 << toString(a.extraParams) << " != " << toString(b.extraParams);
711 return true;
712 }
713
// Checks that |operand| has the form this file requires for an IF condition input
// and for a WHILE condition model's output: a TENSOR_BOOL8 of rank 1 whose single
// dimension is exactly 1. An unspecified dimension (0) is rejected. Returns true
// when every check passes.
static bool validateConditionOperand(const Operand& operand) {
    NN_RET_CHECK(operand.type == OperandType::TENSOR_BOOL8)
            << "Unexpected condition operand type: " << toString(operand.type);
    NN_RET_CHECK_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
    // Indexing is safe here: the rank check above returns early unless exactly one
    // dimension is present.
    NN_RET_CHECK_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
    return true;
}
721
// Asserts with CHECK (android-base, fatal on failure) that each subgraph callback in
// |helper| is set. Called at the top of validateIfOperation() and
// validateWhileOperation(), both of which invoke these callbacks unconditionally.
static void checkSubgraphValidationHelper(const SubgraphValidationHelper& helper) {
    CHECK(helper.isValidSubgraphReference != nullptr);
    CHECK(helper.getSubgraphInputCount != nullptr);
    CHECK(helper.getSubgraphOutputCount != nullptr);
    CHECK(helper.getSubgraphInputOperand != nullptr);
    CHECK(helper.getSubgraphOutputOperand != nullptr);
}
729
validateIfOperation(uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs,const std::vector<Operand> & operands,const SubgraphValidationHelper & helper)730 static bool validateIfOperation(uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
731 const uint32_t* outputs, const std::vector<Operand>& operands,
732 const SubgraphValidationHelper& helper) {
733 namespace op = operation_if;
734 checkSubgraphValidationHelper(helper);
735 NN_RET_CHECK_GE(inputCount, 3u) << "ANEURALNETWORKS_IF must have at least 3 inputs";
736 NN_RET_CHECK_GE(outputCount, 1u) << "ANEURALNETWORKS_IF must have at least 1 output";
737 auto validateBranchOperand = [&](const Operand& branchModelOperand) -> bool {
738 NN_RET_CHECK(helper.isValidSubgraphReference(branchModelOperand))
739 << "Operand is not a valid subgraph reference";
740 const uint32_t branchModelInputCount = helper.getSubgraphInputCount(branchModelOperand);
741 const uint32_t branchModelOutputCount = helper.getSubgraphOutputCount(branchModelOperand);
742 NN_RET_CHECK_EQ(inputCount, op::kFirstInput + branchModelInputCount);
743 NN_RET_CHECK_EQ(outputCount, branchModelOutputCount);
744 for (uint32_t i = 0; i < branchModelInputCount; ++i) {
745 const Operand& innerOperand = *helper.getSubgraphInputOperand(branchModelOperand, i);
746 const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
747 NN_RET_CHECK(compatible(innerOperand, outerOperand));
748 }
749 for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
750 const Operand& innerOperand = *helper.getSubgraphOutputOperand(branchModelOperand, i);
751 const Operand& outerOperand = operands[outputs[i]];
752 NN_RET_CHECK(compatible(innerOperand, outerOperand));
753 }
754 return true;
755 };
756 NN_RET_CHECK(validateConditionOperand(operands[inputs[op::kCondBoolOperand]]))
757 << "Validation failed for IF condition operand";
758 NN_RET_CHECK(validateBranchOperand(operands[inputs[op::kThenModelOperand]]))
759 << "Validation failed for IF then model";
760 NN_RET_CHECK(validateBranchOperand(operands[inputs[op::kElseModelOperand]]))
761 << "Validation failed for IF else model";
762 return true;
763 }
764
validateControlFlowOperandUnknownSize(const SubgraphValidationHelper & helper,const Operand & operand)765 static bool validateControlFlowOperandUnknownSize(const SubgraphValidationHelper& helper,
766 const Operand& operand) {
767 if (!helper.allowControlFlowOperationWithOperandOfUnknownSize &&
768 !isExtensionOperandType(operand.type)) {
769 NN_RET_CHECK_NE(nonExtensionOperandSizeOfData(operand.type, operand.dimensions), 0u);
770 }
771 return true;
772 }
773
// Validates an ANEURALNETWORKS_WHILE operation. |inputs| and |outputs| are indexes
// into |operands|. The inputs at op::kCondModelOperand and op::kBodyModelOperand
// must reference the condition and body subgraphs. The loop operands start at
// op::kFirstInput. Returns true if every check passes.
static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs,
                                   uint32_t outputCount, const uint32_t* outputs,
                                   const std::vector<Operand>& operands,
                                   const SubgraphValidationHelper& helper) {
    // Let the loop have
    //   - m >= 1 input-output operands,
    //   - k >= 0 state-only operands, and
    //   - n >= 0 input-only operands.
    // Then
    //   - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
    //   - the condition model has (m + k + n) inputs and 1 output.
    //   - the body model has (m + k + n) inputs and (m + k) outputs.
    namespace op = operation_while;
    checkSubgraphValidationHelper(helper);
    NN_RET_CHECK_GE(inputCount, 3u) << "ANEURALNETWORKS_WHILE must have at least 3 inputs";
    NN_RET_CHECK_GE(outputCount, 1u) << "ANEURALNETWORKS_WHILE must have at least 1 output";
    // Condition model: takes all (m + k + n) loop inputs and produces exactly one
    // output, which must pass validateConditionOperand().
    auto validateCondOperand = [&](const Operand& condModelOperand) -> bool {
        NN_RET_CHECK(helper.isValidSubgraphReference(condModelOperand))
                << "Operand is not a valid subgraph reference";
        const uint32_t condModelInputCount = helper.getSubgraphInputCount(condModelOperand);
        const uint32_t condModelOutputCount = helper.getSubgraphOutputCount(condModelOperand);
        NN_RET_CHECK_EQ(inputCount, op::kFirstInput + condModelInputCount);
        NN_RET_CHECK_EQ(condModelOutputCount, 1u);
        // Each condition-model input must match the outer WHILE input at the same
        // position, and neither side may have an unknown size (unless allowed).
        for (uint32_t i = 0; i < condModelInputCount; ++i) {
            const Operand& innerOperand = *helper.getSubgraphInputOperand(condModelOperand, i);
            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
            NN_RET_CHECK(compatible(innerOperand, outerOperand));
            NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, innerOperand));
            NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand));
        }
        NN_RET_CHECK(
                validateConditionOperand(*helper.getSubgraphOutputOperand(condModelOperand, 0)));
        return true;
    };
    // Body model: takes all (m + k + n) loop inputs and produces (m + k) outputs.
    auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> bool {
        NN_RET_CHECK(helper.isValidSubgraphReference(bodyModelOperand))
                << "Operand is not a valid subgraph reference";
        const uint32_t bodyModelInputCount = helper.getSubgraphInputCount(bodyModelOperand);
        const uint32_t bodyModelOutputCount = helper.getSubgraphOutputCount(bodyModelOperand);
        NN_RET_CHECK_EQ(inputCount, op::kFirstInput + bodyModelInputCount);
        // These two checks guarantee the unsigned subtractions below cannot wrap.
        NN_RET_CHECK_GE(bodyModelOutputCount, outputCount);
        NN_RET_CHECK_GE(bodyModelInputCount, bodyModelOutputCount);
        // m, k and n from the comment at the top of this function.
        const uint32_t inputOutputCount = outputCount;
        const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
        const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
        // All (m + k + n) body inputs vs. the outer WHILE inputs at the same position.
        for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
            const Operand& innerOperand = *helper.getSubgraphInputOperand(bodyModelOperand, i);
            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
            NN_RET_CHECK(compatible(innerOperand, outerOperand));
            NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, innerOperand));
            NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand));
        }
        // The first m body outputs vs. the WHILE operation's m outputs.
        for (uint32_t i = 0; i < inputOutputCount; ++i) {
            const Operand& innerOperand = *helper.getSubgraphOutputOperand(bodyModelOperand, i);
            const Operand& outerOperand = operands[outputs[i]];
            NN_RET_CHECK(compatible(innerOperand, outerOperand));
            NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outerOperand));
        }
        // Each of the (m + k) body outputs (input-output and state-only operands) must
        // be compatible with the body input at the same position.
        for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
            const Operand& inputOperand = *helper.getSubgraphInputOperand(bodyModelOperand, i);
            const Operand& outputOperand = *helper.getSubgraphOutputOperand(bodyModelOperand, i);
            NN_RET_CHECK(compatible(inputOperand, outputOperand));
            NN_RET_CHECK(validateControlFlowOperandUnknownSize(helper, outputOperand));
        }
        return true;
    };
    NN_RET_CHECK(validateCondOperand(operands[inputs[op::kCondModelOperand]]))
            << "Validation failed for WHILE condition model";
    NN_RET_CHECK(validateBodyOperand(operands[inputs[op::kBodyModelOperand]]))
            << "Validation failed for WHILE body model";
    return true;
}
846
validateOperation(ANeuralNetworksOperationType opType,uint32_t inputCount,const uint32_t * inputIndexes,uint32_t outputCount,const uint32_t * outputIndexes,const std::vector<hal::Operand> & operands,HalVersion halVersion)847 static inline int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
848 const uint32_t* inputIndexes, uint32_t outputCount,
849 const uint32_t* outputIndexes,
850 const std::vector<hal::Operand>& operands,
851 HalVersion halVersion) {
852 if (opType == ANEURALNETWORKS_IF || opType == ANEURALNETWORKS_WHILE) {
853 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
854 LOG(ERROR) << "This validateOperation() overload does not support control flow";
855 return ANEURALNETWORKS_BAD_DATA;
856 }
857 return validateOperation(opType, inputCount, inputIndexes, outputCount, outputIndexes, operands,
858 halVersion, {});
859 }
860
validateOperation(ANeuralNetworksOperationType opType,uint32_t inputCount,const uint32_t * inputIndexes,uint32_t outputCount,const uint32_t * outputIndexes,const std::vector<Operand> & operands,HalVersion halVersion,const SubgraphValidationHelper & helper)861 int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
862 const uint32_t* inputIndexes, uint32_t outputCount,
863 const uint32_t* outputIndexes, const std::vector<Operand>& operands,
864 HalVersion halVersion, const SubgraphValidationHelper& helper) {
865 NN_RETURN_IF_ERROR(validateOperandList(inputCount, inputIndexes,
866 static_cast<uint32_t>(operands.size()),
867 "ANeuralNetworksModel_addOperation inputs"));
868 NN_RETURN_IF_ERROR(validateOperandList(outputCount, outputIndexes,
869 static_cast<uint32_t>(operands.size()),
870 "ANeuralNetworksModel_addOperation outputs"));
871
872 if (isExtensionOperationType(opType)) {
873 if (halVersion < HalVersion::V1_2) {
874 LOG(ERROR)
875 << "Extension operations are supported since HAL version 1.2, validating using "
876 << toString(halVersion);
877 return ANEURALNETWORKS_BAD_DATA;
878 }
879 // There is no other validation we can do for an extension operation.
880 return ANEURALNETWORKS_NO_ERROR;
881 }
882
883 auto logInvalidInOutNumber = [opType, inputCount, outputCount](int expIn, int expOut) {
884 LOG(ERROR) << "Invalid number of input operands (" << inputCount << ", expected " << expIn
885 << ") or output operands (" << outputCount << ", expected " << expOut
886 << ") for operation " << getOperationName(opType);
887 };
888
889 switch (opType) {
890 case ANEURALNETWORKS_OEM_OPERATION: {
891 return ANEURALNETWORKS_NO_ERROR;
892 }
893 case ANEURALNETWORKS_RESHAPE: {
894 if (inputCount != 2 || outputCount != 1) {
895 logInvalidInOutNumber(2, 1);
896 return ANEURALNETWORKS_BAD_DATA;
897 }
898 auto inputType = operands[inputIndexes[0]].type;
899 std::vector<OperandType> inExpectedTypes;
900 std::vector<OperandType> outExpectedTypes;
901 if (inputType == OperandType::TENSOR_FLOAT32) {
902 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
903 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32};
904 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
905 } else if (inputType == OperandType::TENSOR_FLOAT16) {
906 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
907 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32};
908 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
909 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
910 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
911 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32};
912 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
913 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
914 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
915 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
916 OperandType::TENSOR_INT32};
917 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
918 } else {
919 LOG(ERROR) << "Unsupported input tensor type for operation "
920 << getOperationName(opType);
921 return ANEURALNETWORKS_BAD_DATA;
922 }
923 const auto inputRank = operands[inputIndexes[0]].dimensions.size();
924 if (inputRank > 4) {
925 LOG(ERROR) << "Unsupported input tensor rank for operation "
926 << getOperationName(opType);
927 return ANEURALNETWORKS_BAD_DATA;
928 }
929 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
930 inExpectedTypes, outputCount, outputIndexes,
931 outExpectedTypes);
932 }
933 case ANEURALNETWORKS_DEPTH_TO_SPACE: {
934 if ((inputCount != 3 && inputCount != 2) || outputCount != 1) {
935 LOG(ERROR) << "Invalid number of input operands (" << inputCount
936 << ", expected 3 or 2) or output operands (" << outputCount
937 << ", expected 1) for operation " << getOperationName(opType);
938 return ANEURALNETWORKS_BAD_DATA;
939 }
940 auto inputType = operands[inputIndexes[0]].type;
941 std::vector<OperandType> inExpectedTypes;
942 std::vector<OperandType> outExpectedTypes;
943 if (inputType == OperandType::TENSOR_FLOAT32) {
944 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
945 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
946 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
947 } else if (inputType == OperandType::TENSOR_FLOAT16) {
948 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
949 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
950 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
951 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
952 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
953 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
954 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
955 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
956 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
957 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
958 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
959 } else {
960 LOG(ERROR) << "Unsupported input tensor type for operation "
961 << getOperationName(opType);
962 return ANEURALNETWORKS_BAD_DATA;
963 }
964 if (inputCount == 3) {
965 inExpectedTypes.push_back(OperandType::BOOL);
966 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
967 } else {
968 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
969 }
970 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
971 inExpectedTypes, outputCount, outputIndexes,
972 outExpectedTypes);
973 }
974 case ANEURALNETWORKS_SPACE_TO_DEPTH: {
975 if ((inputCount != 3 && inputCount != 2) || outputCount != 1) {
976 LOG(ERROR) << "Invalid number of input operands (" << inputCount
977 << ", expected 3 or 2) or output operands (" << outputCount
978 << ", expected 1) for operation " << getOperationName(opType);
979 return ANEURALNETWORKS_BAD_DATA;
980 }
981 auto inputType = operands[inputIndexes[0]].type;
982 std::vector<OperandType> inExpectedTypes;
983 std::vector<OperandType> outExpectedTypes;
984 if (inputType == OperandType::TENSOR_FLOAT32) {
985 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
986 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
987 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
988 } else if (inputType == OperandType::TENSOR_FLOAT16) {
989 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
990 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
991 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
992 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
993 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
994 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
995 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
996 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
997 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
998 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
999 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1000 } else {
1001 LOG(ERROR) << "Unsupported input tensor type for operation "
1002 << getOperationName(opType);
1003 return ANEURALNETWORKS_BAD_DATA;
1004 }
1005 if (inputCount == 3) {
1006 inExpectedTypes.push_back(OperandType::BOOL);
1007 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1008 } else {
1009 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1010 }
1011 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1012 inExpectedTypes, outputCount, outputIndexes,
1013 outExpectedTypes);
1014 }
1015 case ANEURALNETWORKS_EMBEDDING_LOOKUP: {
1016 if (inputCount != 2 || outputCount != 1) {
1017 logInvalidInOutNumber(2, 1);
1018 return ANEURALNETWORKS_BAD_DATA;
1019 }
1020 auto inputType = operands[inputIndexes[1]].type;
1021 if (inputType != OperandType::TENSOR_FLOAT16 &&
1022 inputType != OperandType::TENSOR_FLOAT32 &&
1023 inputType != OperandType::TENSOR_INT32 &&
1024 inputType != OperandType::TENSOR_QUANT8_ASYMM &&
1025 inputType != OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1026 LOG(ERROR) << "Unsupported input tensor type for operation "
1027 << getOperationName(opType);
1028 return ANEURALNETWORKS_BAD_DATA;
1029 }
1030 std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, inputType};
1031 std::vector<OperandType> outExpectedTypes = {inputType};
1032 if (inputType == OperandType::TENSOR_FLOAT16 ||
1033 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1034 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1035 } else if (inputType == OperandType::TENSOR_INT32 ||
1036 inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1037 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1038 } else {
1039 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1040 }
1041 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1042 inExpectedTypes, outputCount, outputIndexes,
1043 outExpectedTypes);
1044 }
1045 case ANEURALNETWORKS_HASHTABLE_LOOKUP: {
1046 if (inputCount != 3 || outputCount != 2) {
1047 logInvalidInOutNumber(3, 2);
1048 return ANEURALNETWORKS_BAD_DATA;
1049 }
1050 auto inputType = operands[inputIndexes[2]].type;
1051 if (inputType != OperandType::TENSOR_FLOAT32 &&
1052 inputType != OperandType::TENSOR_INT32 &&
1053 inputType != OperandType::TENSOR_QUANT8_ASYMM) {
1054 LOG(ERROR) << "Unsupported input tensor type for operation "
1055 << getOperationName(opType);
1056 return ANEURALNETWORKS_BAD_DATA;
1057 }
1058 std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32,
1059 OperandType::TENSOR_INT32, inputType};
1060 std::vector<OperandType> outExpectedTypes = {inputType,
1061 OperandType::TENSOR_QUANT8_ASYMM};
1062 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1063 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1064 inExpectedTypes, outputCount, outputIndexes,
1065 outExpectedTypes);
1066 }
1067 case ANEURALNETWORKS_LSH_PROJECTION: {
1068 if (inputCount != 4 || outputCount != 1) {
1069 logInvalidInOutNumber(4, 1);
1070 return ANEURALNETWORKS_BAD_DATA;
1071 }
1072 auto inputType = operands[inputIndexes[1]].type;
1073 if (inputType != OperandType::TENSOR_FLOAT16 &&
1074 inputType != OperandType::TENSOR_FLOAT32 &&
1075 inputType != OperandType::TENSOR_INT32 &&
1076 inputType != OperandType::TENSOR_QUANT8_ASYMM) {
1077 LOG(ERROR) << "Unsupported input tensor type for operation "
1078 << getOperationName(opType);
1079 return ANEURALNETWORKS_BAD_DATA;
1080 }
1081 auto hashType = operands[inputIndexes[0]].type;
1082 std::vector<OperandType> inExpectedTypes;
1083 if (hashType == OperandType::TENSOR_FLOAT16) {
1084 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1085 inExpectedTypes = {
1086 OperandType::TENSOR_FLOAT16,
1087 inputType,
1088 OperandType::TENSOR_FLOAT16,
1089 OperandType::INT32,
1090 };
1091 } else if (hashType == OperandType::TENSOR_FLOAT32) {
1092 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1093 inExpectedTypes = {
1094 OperandType::TENSOR_FLOAT32,
1095 inputType,
1096 OperandType::TENSOR_FLOAT32,
1097 OperandType::INT32,
1098 };
1099 } else {
1100 LOG(ERROR) << "Unsupported hash tensor type for operation "
1101 << getOperationName(opType);
1102 return ANEURALNETWORKS_BAD_DATA;
1103 }
1104 std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
1105 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1106 inExpectedTypes, outputCount, outputIndexes,
1107 outExpectedTypes);
1108 }
1109 case ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM: {
1110 const uint32_t kNumOutputs = 2;
1111 const uint32_t kNumOutputsMerged = 1;
1112 const uint32_t kNumOutputsWithState = 6;
1113 const uint32_t kNumOutputsMergedWithState = 5;
1114 if (inputCount != 61 ||
1115 (outputCount != kNumOutputs && outputCount != kNumOutputsMerged &&
1116 outputCount != kNumOutputsWithState &&
1117 outputCount != kNumOutputsMergedWithState)) {
1118 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1119 << ", expected 61) or output operands (" << outputCount
1120 << ", expected 1, 2, 5 or 6) for operation " << getOperationName(opType);
1121 return ANEURALNETWORKS_BAD_DATA;
1122 }
1123
1124 std::vector<OperandType> inExpectedTypes;
1125 auto inputType = operands[inputIndexes[0]].type;
1126 if (inputType != OperandType::TENSOR_FLOAT32 &&
1127 inputType != OperandType::TENSOR_FLOAT16) {
1128 LOG(ERROR) << "Unsupported input tensor type for operation "
1129 << getOperationName(opType);
1130 return ANEURALNETWORKS_BAD_DATA;
1131 }
1132
1133 inExpectedTypes = {};
1134 for (int i = 0; i < 48; ++i) {
1135 inExpectedTypes.push_back(inputType);
1136 }
1137 inExpectedTypes.push_back(OperandType::INT32);
1138 inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
1139 ? OperandType::FLOAT32
1140 : OperandType::FLOAT16);
1141 inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
1142 ? OperandType::FLOAT32
1143 : OperandType::FLOAT16);
1144 inExpectedTypes.push_back(OperandType::BOOL);
1145 inExpectedTypes.push_back(OperandType::BOOL);
1146 for (int i = 0; i < 8; ++i) {
1147 inExpectedTypes.push_back(inputType);
1148 }
1149
1150 HalVersion minSupportedHalVersion = HalVersion::V1_2;
1151 if (outputCount == kNumOutputsWithState || outputCount == kNumOutputsMergedWithState) {
1152 minSupportedHalVersion = HalVersion::V1_3;
1153 }
1154 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, minSupportedHalVersion));
1155 std::vector<OperandType> outExpectedTypes(outputCount, inputType);
1156 auto status = validateOperationOperandTypes(operands, inputCount, inputIndexes,
1157 inExpectedTypes, outputCount, outputIndexes,
1158 outExpectedTypes);
1159 return status;
1160 }
1161 case ANEURALNETWORKS_LSTM: {
1162 if ((inputCount != 23 && inputCount != 27) || outputCount != 4) {
1163 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1164 << ", expected 23 or 27) or output operands (" << outputCount
1165 << ", expected 4) for operation " << getOperationName(opType);
1166 return ANEURALNETWORKS_BAD_DATA;
1167 }
1168 std::vector<OperandType> inExpectedTypes;
1169 std::vector<OperandType> outExpectedTypes;
1170 auto inputType = operands[inputIndexes[0]].type;
1171 if (inputType != OperandType::TENSOR_FLOAT32 &&
1172 inputType != OperandType::TENSOR_FLOAT16) {
1173 LOG(ERROR) << "Unsupported input tensor type for operation "
1174 << getOperationName(opType);
1175 return ANEURALNETWORKS_BAD_DATA;
1176 }
1177
1178 inExpectedTypes = {inputType, inputType, inputType, inputType, inputType,
1179 inputType, inputType, inputType, inputType, inputType,
1180 inputType, inputType, inputType, inputType, inputType,
1181 inputType, inputType, inputType, inputType, inputType,
1182 OperandType::INT32};
1183 if (inputType == OperandType::TENSOR_FLOAT32) {
1184 inExpectedTypes.push_back(OperandType::FLOAT32);
1185 inExpectedTypes.push_back(OperandType::FLOAT32);
1186 } else {
1187 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1188 inExpectedTypes.push_back(OperandType::FLOAT16);
1189 inExpectedTypes.push_back(OperandType::FLOAT16);
1190 }
1191
1192 outExpectedTypes = {inputType, inputType, inputType, inputType};
1193 if (inputCount == 23) {
1194 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1195 } else {
1196 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1197 for (int i = 0; i < 4; ++i) {
1198 inExpectedTypes.push_back(inputType);
1199 }
1200 }
1201 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1202 inExpectedTypes, outputCount, outputIndexes,
1203 outExpectedTypes);
1204 }
1205 case ANEURALNETWORKS_QUANTIZED_16BIT_LSTM: {
1206 if (inputCount != 15 || outputCount != 2) {
1207 logInvalidInOutNumber(15, 2);
1208 return ANEURALNETWORKS_BAD_DATA;
1209 }
1210 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1211 std::vector<OperandType> inExpectedTypes = {
1212 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
1213 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
1214 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
1215 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
1216 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32,
1217 OperandType::TENSOR_INT32, OperandType::TENSOR_INT32,
1218 OperandType::TENSOR_INT32, OperandType::TENSOR_QUANT16_SYMM,
1219 OperandType::TENSOR_QUANT8_ASYMM};
1220 std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_QUANT16_SYMM,
1221 OperandType::TENSOR_QUANT8_ASYMM};
1222 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1223 inExpectedTypes, outputCount, outputIndexes,
1224 outExpectedTypes);
1225 }
1226 case ANEURALNETWORKS_RANDOM_MULTINOMIAL: {
1227 if (inputCount != 3 || outputCount != 1) {
1228 logInvalidInOutNumber(3, 1);
1229 return ANEURALNETWORKS_BAD_DATA;
1230 }
1231 OperandType inputType = operands[inputIndexes[0]].type;
1232 std::vector<OperandType> inExpectedTypes;
1233 if (inputType == OperandType::TENSOR_FLOAT32 ||
1234 inputType == OperandType::TENSOR_FLOAT16) {
1235 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1236 inExpectedTypes = {
1237 inputType,
1238 OperandType::INT32,
1239 OperandType::TENSOR_INT32,
1240 };
1241 } else {
1242 LOG(ERROR) << "Unsupported input tensor type for operation "
1243 << getOperationName(opType);
1244 return ANEURALNETWORKS_BAD_DATA;
1245 }
1246 std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
1247 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1248 inExpectedTypes, outputCount, outputIndexes,
1249 outExpectedTypes);
1250 }
1251 case ANEURALNETWORKS_RNN: {
1252 if (inputCount != 6 || outputCount != 2) {
1253 logInvalidInOutNumber(6, 2);
1254 return ANEURALNETWORKS_BAD_DATA;
1255 }
1256 OperandType inputType = operands[inputIndexes[0]].type;
1257 std::vector<OperandType> inExpectedTypes;
1258 std::vector<OperandType> outExpectedTypes;
1259 if (inputType == OperandType::TENSOR_FLOAT32) {
1260 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1261 inExpectedTypes = {
1262 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
1263 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
1264 OperandType::TENSOR_FLOAT32, OperandType::INT32,
1265 };
1266 outExpectedTypes = {
1267 OperandType::TENSOR_FLOAT32,
1268 OperandType::TENSOR_FLOAT32,
1269 };
1270 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1271 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1272 inExpectedTypes = {
1273 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
1274 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
1275 OperandType::TENSOR_FLOAT16, OperandType::INT32,
1276 };
1277 outExpectedTypes = {
1278 OperandType::TENSOR_FLOAT16,
1279 OperandType::TENSOR_FLOAT16,
1280 };
1281 } else {
1282 LOG(ERROR) << "Unsupported input tensor type for operation "
1283 << getOperationName(opType);
1284 return ANEURALNETWORKS_BAD_DATA;
1285 }
1286 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1287 inExpectedTypes, outputCount, outputIndexes,
1288 outExpectedTypes);
1289 }
1290 case ANEURALNETWORKS_SVDF: {
1291 if (inputCount != 7 || outputCount != 2) {
1292 logInvalidInOutNumber(7, 2);
1293 return ANEURALNETWORKS_BAD_DATA;
1294 }
1295 OperandType inputType = operands[inputIndexes[0]].type;
1296 if (inputType == OperandType::TENSOR_FLOAT32) {
1297 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
1298
1299 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1300 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1301 } else {
1302 LOG(ERROR) << "Unsupported input tensor type for operation "
1303 << getOperationName(opType);
1304 return ANEURALNETWORKS_BAD_DATA;
1305 }
1306 std::vector<OperandType> inExpectedTypes = {
1307 inputType, inputType, inputType, inputType,
1308 inputType, OperandType::INT32, OperandType::INT32,
1309 };
1310 std::vector<OperandType> outExpectedTypes = {inputType, inputType};
1311 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1312 inExpectedTypes, outputCount, outputIndexes,
1313 outExpectedTypes);
1314 }
1315 case ANEURALNETWORKS_BATCH_TO_SPACE_ND: {
1316 if ((inputCount != 3 && inputCount != 2) || outputCount != 1) {
1317 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1318 << ", expected 3 or 2) or output operands (" << outputCount
1319 << ", expected 1) for operation " << getOperationName(opType);
1320 return ANEURALNETWORKS_BAD_DATA;
1321 }
1322 auto inputType = operands[inputIndexes[0]].type;
1323 std::vector<OperandType> inExpectedTypes;
1324 std::vector<OperandType> outExpectedTypes;
1325 if (inputType == OperandType::TENSOR_FLOAT32) {
1326 inExpectedTypes = {
1327 OperandType::TENSOR_FLOAT32,
1328 OperandType::TENSOR_INT32,
1329 };
1330 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1331 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1332 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1333 inExpectedTypes = {
1334 OperandType::TENSOR_FLOAT16,
1335 OperandType::TENSOR_INT32,
1336 };
1337 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1338 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1339 inExpectedTypes = {
1340 OperandType::TENSOR_QUANT8_ASYMM,
1341 OperandType::TENSOR_INT32,
1342 };
1343 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1344 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1345 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1346 inExpectedTypes = {
1347 OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
1348 OperandType::TENSOR_INT32,
1349 };
1350 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1351 } else {
1352 LOG(ERROR) << "Unsupported input tensor type for operation "
1353 << getOperationName(opType);
1354 return ANEURALNETWORKS_BAD_DATA;
1355 }
1356 if (inputCount == 3) {
1357 inExpectedTypes.push_back(OperandType::BOOL);
1358 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1359 } else {
1360 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1361 }
1362 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1363 inExpectedTypes, outputCount, outputIndexes,
1364 outExpectedTypes);
1365 }
1366 case ANEURALNETWORKS_SPACE_TO_BATCH_ND: {
1367 if ((inputCount != 4 && inputCount != 3) || outputCount != 1) {
1368 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1369 << ", expected 4 or 3) or output operands (" << outputCount
1370 << ", expected 1) for operation " << getOperationName(opType);
1371 return ANEURALNETWORKS_BAD_DATA;
1372 }
1373 auto inputType = operands[inputIndexes[0]].type;
1374 std::vector<OperandType> inExpectedTypes;
1375 std::vector<OperandType> outExpectedTypes;
1376 if (inputType == OperandType::TENSOR_FLOAT32) {
1377 inExpectedTypes = {
1378 OperandType::TENSOR_FLOAT32,
1379 OperandType::TENSOR_INT32,
1380 OperandType::TENSOR_INT32,
1381 };
1382 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1383 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1384 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1385 inExpectedTypes = {
1386 OperandType::TENSOR_FLOAT16,
1387 OperandType::TENSOR_INT32,
1388 OperandType::TENSOR_INT32,
1389 };
1390 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1391 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1392 if (operands[inputIndexes[0]].zeroPoint != 0) {
1393 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1394 }
1395 inExpectedTypes = {
1396 OperandType::TENSOR_QUANT8_ASYMM,
1397 OperandType::TENSOR_INT32,
1398 OperandType::TENSOR_INT32,
1399 };
1400 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1401 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1402 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1403 inExpectedTypes = {
1404 OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
1405 OperandType::TENSOR_INT32,
1406 OperandType::TENSOR_INT32,
1407 };
1408 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1409 } else {
1410 LOG(ERROR) << "Unsupported input tensor type for operation "
1411 << getOperationName(opType);
1412 return ANEURALNETWORKS_BAD_DATA;
1413 }
1414 if (inputCount == 4) {
1415 inExpectedTypes.push_back(OperandType::BOOL);
1416 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1417 } else {
1418 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1419 }
1420 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1421 inExpectedTypes, outputCount, outputIndexes,
1422 outExpectedTypes);
1423 }
1424 case ANEURALNETWORKS_PAD: {
1425 if (inputCount != 2 || outputCount != 1) {
1426 logInvalidInOutNumber(2, 1);
1427 return ANEURALNETWORKS_BAD_DATA;
1428 }
1429 auto inputType = operands[inputIndexes[0]].type;
1430 std::vector<OperandType> inExpectedTypes;
1431 std::vector<OperandType> outExpectedTypes;
1432 if (inputType == OperandType::TENSOR_FLOAT32) {
1433 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1434 inExpectedTypes = {
1435 OperandType::TENSOR_FLOAT32,
1436 OperandType::TENSOR_INT32,
1437 };
1438 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1439 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1440 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1441 inExpectedTypes = {
1442 OperandType::TENSOR_FLOAT16,
1443 OperandType::TENSOR_INT32,
1444 };
1445 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1446 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1447 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1448 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1449 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1450 } else {
1451 if (operands[inputIndexes[0]].zeroPoint == 0) {
1452 NN_RETURN_IF_ERROR(
1453 validateHalVersion(opType, halVersion, HalVersion::V1_1));
1454 } else {
1455 NN_RETURN_IF_ERROR(
1456 validateHalVersion(opType, halVersion, HalVersion::V1_2));
1457 }
1458 }
1459 inExpectedTypes = {
1460 inputType,
1461 OperandType::TENSOR_INT32,
1462 };
1463 outExpectedTypes = {inputType};
1464 } else {
1465 LOG(ERROR) << "Unsupported input tensor type for operation "
1466 << getOperationName(opType);
1467 return ANEURALNETWORKS_BAD_DATA;
1468 }
1469 const auto inputRank = operands[inputIndexes[0]].dimensions.size();
1470 if (inputRank > 4) {
1471 LOG(ERROR) << "Unsupported input tensor rank for operation "
1472 << getOperationName(opType);
1473 return ANEURALNETWORKS_BAD_DATA;
1474 }
1475 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1476 inExpectedTypes, outputCount, outputIndexes,
1477 outExpectedTypes);
1478 }
1479 case ANEURALNETWORKS_PAD_V2: {
1480 if (inputCount != 3 || outputCount != 1) {
1481 logInvalidInOutNumber(3, 1);
1482 return ANEURALNETWORKS_BAD_DATA;
1483 }
1484 auto inputType = operands[inputIndexes[0]].type;
1485 std::vector<OperandType> inExpectedTypes;
1486 std::vector<OperandType> outExpectedTypes;
1487 if (inputType == OperandType::TENSOR_FLOAT32) {
1488 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1489 inExpectedTypes = {
1490 OperandType::TENSOR_FLOAT32,
1491 OperandType::TENSOR_INT32,
1492 OperandType::FLOAT32,
1493 };
1494 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1495 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1496 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1497 inExpectedTypes = {
1498 OperandType::TENSOR_FLOAT16,
1499 OperandType::TENSOR_INT32,
1500 OperandType::FLOAT16,
1501 };
1502 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1503 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1504 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1505 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1506 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1507 } else {
1508 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1509 }
1510 inExpectedTypes = {
1511 inputType,
1512 OperandType::TENSOR_INT32,
1513 OperandType::INT32,
1514 }; // TODO(b/116699425): Make it UINT8.
1515 outExpectedTypes = {inputType};
1516 } else {
1517 LOG(ERROR) << "Unsupported input tensor type for operation "
1518 << getOperationName(opType);
1519 return ANEURALNETWORKS_BAD_DATA;
1520 }
1521 const auto inputRank = operands[inputIndexes[0]].dimensions.size();
1522 if (inputRank > 4) {
1523 LOG(ERROR) << "Unsupported input tensor rank for operation "
1524 << getOperationName(opType);
1525 return ANEURALNETWORKS_BAD_DATA;
1526 }
1527 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1528 inExpectedTypes, outputCount, outputIndexes,
1529 outExpectedTypes);
1530 }
1531 case ANEURALNETWORKS_CAST: {
1532 if (inputCount != 1 || outputCount != 1) {
1533 logInvalidInOutNumber(1, 1);
1534 return ANEURALNETWORKS_BAD_DATA;
1535 }
1536 auto inputOperand = operands[inputIndexes[0]];
1537 auto outputOperand = operands[outputIndexes[0]];
1538 auto inputType = inputOperand.type;
1539 auto outputType = outputOperand.type;
1540 std::vector<OperandType> inExpectedTypes;
1541 std::vector<OperandType> outExpectedTypes;
1542 if ((inputType == OperandType::TENSOR_FLOAT16 ||
1543 inputType == OperandType::TENSOR_FLOAT32 ||
1544 inputType == OperandType::TENSOR_INT32 ||
1545 inputType == OperandType::TENSOR_QUANT8_ASYMM) &&
1546 (outputType == OperandType::TENSOR_FLOAT16 ||
1547 outputType == OperandType::TENSOR_FLOAT32 ||
1548 outputType == OperandType::TENSOR_INT32 ||
1549 outputType == OperandType::TENSOR_QUANT8_ASYMM)) {
1550 inExpectedTypes = {inputType};
1551 outExpectedTypes = {outputType};
1552 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1553 } else if (inputType == OperandType::TENSOR_BOOL8 ||
1554 inputType == OperandType::TENSOR_QUANT16_ASYMM ||
1555 inputType == OperandType::TENSOR_QUANT16_SYMM ||
1556 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
1557 inputType == OperandType::TENSOR_QUANT8_SYMM) {
1558 inExpectedTypes = {inputType};
1559 outExpectedTypes = {inputType}; // Only identity CAST is supported.
1560 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1561 } else {
1562 LOG(ERROR) << "Unsupported data type for operation " << getOperationName(opType);
1563 return ANEURALNETWORKS_BAD_DATA;
1564 }
1565 // Validate that output shape is equal to input shape if dimensions
1566 // are already known.
1567 auto getNumberOfElements = [](const hardware::hidl_vec<uint32_t>& dims) {
1568 if (dims.size() == 0) {
1569 return 0;
1570 }
1571 return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
1572 };
1573 if (inputOperand.dimensions.size() != 0 && outputOperand.dimensions.size() != 0 &&
1574 getNumberOfElements(outputOperand.dimensions) != 0 &&
1575 inputOperand.dimensions != outputOperand.dimensions) {
1576 return ANEURALNETWORKS_BAD_DATA;
1577 }
1578 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1579 inExpectedTypes, outputCount, outputIndexes,
1580 outExpectedTypes);
1581 }
1582 case ANEURALNETWORKS_MEAN: {
1583 if (inputCount != 3 || outputCount != 1) {
1584 logInvalidInOutNumber(3, 1);
1585 return ANEURALNETWORKS_BAD_DATA;
1586 }
1587 const auto inputRank = operands[inputIndexes[0]].dimensions.size();
1588 if (inputRank > 4) {
1589 LOG(ERROR) << "Unsupported input tensor rank for operation "
1590 << getOperationName(opType);
1591 return ANEURALNETWORKS_BAD_DATA;
1592 }
1593 auto inputType = operands[inputIndexes[0]].type;
1594 if (inputType == OperandType::TENSOR_FLOAT32) {
1595 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1596 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1597 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1598 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1599 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
1600 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1601 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1602 } else {
1603 LOG(ERROR) << "Unsupported input tensor type for operation "
1604 << getOperationName(opType);
1605 return ANEURALNETWORKS_BAD_DATA;
1606 }
1607 std::vector<OperandType> inExpectedTypes = {inputType, OperandType::TENSOR_INT32,
1608 OperandType::INT32};
1609 std::vector<OperandType> outExpectedTypes = {inputType};
1610 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1611 inExpectedTypes, outputCount, outputIndexes,
1612 outExpectedTypes);
1613 }
1614 case ANEURALNETWORKS_ARGMAX:
1615 case ANEURALNETWORKS_ARGMIN: {
1616 if (inputCount != 2 || outputCount != 1) {
1617 logInvalidInOutNumber(2, 1);
1618 return ANEURALNETWORKS_BAD_DATA;
1619 }
1620 auto inputType = operands[inputIndexes[0]].type;
1621 std::vector<OperandType> inExpectedTypes;
1622 std::vector<OperandType> outExpectedTypes;
1623 if (inputType == OperandType::TENSOR_FLOAT16 ||
1624 inputType == OperandType::TENSOR_FLOAT32 ||
1625 inputType == OperandType::TENSOR_INT32 ||
1626 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1627 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1628 inExpectedTypes = {inputType, OperandType::INT32};
1629 outExpectedTypes = {OperandType::TENSOR_INT32};
1630 } else {
1631 LOG(ERROR) << "Unsupported input tensor type for operation "
1632 << getOperationName(opType);
1633 return ANEURALNETWORKS_BAD_DATA;
1634 }
1635 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1636 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1637 inExpectedTypes, outputCount, outputIndexes,
1638 outExpectedTypes);
1639 }
1640 case ANEURALNETWORKS_EXPAND_DIMS: {
1641 if (inputCount != 2 || outputCount != 1) {
1642 logInvalidInOutNumber(2, 1);
1643 return ANEURALNETWORKS_BAD_DATA;
1644 }
1645 auto inputType = operands[inputIndexes[0]].type;
1646 std::vector<OperandType> inExpectedTypes;
1647 std::vector<OperandType> outExpectedTypes;
1648 if (inputType == OperandType::TENSOR_FLOAT16 ||
1649 inputType == OperandType::TENSOR_FLOAT32 ||
1650 inputType == OperandType::TENSOR_INT32 ||
1651 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1652 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1653 inExpectedTypes = {inputType, OperandType::INT32};
1654 outExpectedTypes = {inputType};
1655 } else {
1656 LOG(ERROR) << "Unsupported input tensor type for operation "
1657 << getOperationName(opType);
1658 return ANEURALNETWORKS_BAD_DATA;
1659 }
1660 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1661 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1662 } else {
1663 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1664 }
1665 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1666 inExpectedTypes, outputCount, outputIndexes,
1667 outExpectedTypes);
1668 }
1669 case ANEURALNETWORKS_SPLIT: {
1670 if (inputCount != 3) {
1671 LOG(ERROR) << "Invalid number of input operands (" << inputCount << ", expected 3)"
1672 << getOperationName(opType);
1673 return ANEURALNETWORKS_BAD_DATA;
1674 }
1675 auto inputType = operands[inputIndexes[0]].type;
1676 if (inputType != OperandType::TENSOR_FLOAT16 &&
1677 inputType != OperandType::TENSOR_FLOAT32 &&
1678 inputType != OperandType::TENSOR_INT32 &&
1679 inputType != OperandType::TENSOR_QUANT8_ASYMM &&
1680 inputType != OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1681 LOG(ERROR) << "Unsupported input tensor type for operation "
1682 << getOperationName(opType);
1683 return ANEURALNETWORKS_BAD_DATA;
1684 }
1685 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1686 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1687 } else {
1688 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1689 }
1690 std::vector<OperandType> inExpectedTypes = {inputType, OperandType::INT32,
1691 OperandType::INT32};
1692 std::vector<OperandType> outExpectedTypes(outputCount, inputType);
1693 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1694 inExpectedTypes, outputCount, outputIndexes,
1695 outExpectedTypes);
1696 }
1697 case ANEURALNETWORKS_MAXIMUM:
1698 case ANEURALNETWORKS_MINIMUM: {
1699 if (inputCount != 2 || outputCount != 1) {
1700 logInvalidInOutNumber(2, 1);
1701 return ANEURALNETWORKS_BAD_DATA;
1702 }
1703 std::vector<OperandType> inExpectedTypes;
1704 std::vector<OperandType> outExpectedTypes;
1705 OperandType inputType = operands[inputIndexes[0]].type;
1706 if (inputType == OperandType::TENSOR_FLOAT16 ||
1707 inputType == OperandType::TENSOR_FLOAT32 ||
1708 inputType == OperandType::TENSOR_INT32 ||
1709 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1710 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1711 inExpectedTypes = {inputType, inputType};
1712 outExpectedTypes = {inputType};
1713 } else {
1714 LOG(ERROR) << "Unsupported input tensor type for operation "
1715 << getOperationName(opType);
1716 return ANEURALNETWORKS_BAD_DATA;
1717 }
1718 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1719 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1720 } else {
1721 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1722 }
1723 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1724 inExpectedTypes, outputCount, outputIndexes,
1725 outExpectedTypes);
1726 }
1727 case ANEURALNETWORKS_GROUPED_CONV_2D: {
1728 if ((inputCount != 12 && inputCount != 9) || outputCount != 1) {
1729 LOG(ERROR) << "Invalid number of input operands (" << inputCount
1730 << ", expected 12 or 9) or output operands (" << outputCount
1731 << ", expected 1) for operation " << getOperationName(opType);
1732 return ANEURALNETWORKS_BAD_DATA;
1733 }
1734 auto inputType = operands[inputIndexes[0]].type;
1735 auto filterType = operands[inputIndexes[1]].type;
1736 std::vector<OperandType> inExpectedTypes;
1737 std::vector<OperandType> outExpectedTypes;
1738 if (inputType == OperandType::TENSOR_FLOAT32) {
1739 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
1740 OperandType::TENSOR_FLOAT32, OperandType::INT32,
1741 OperandType::INT32, OperandType::INT32,
1742 OperandType::INT32, OperandType::INT32};
1743 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1744 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1745 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
1746 OperandType::TENSOR_FLOAT16, OperandType::INT32,
1747 OperandType::INT32, OperandType::INT32,
1748 OperandType::INT32, OperandType::INT32};
1749 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1750 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1751 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1752 if (filterType != inputType &&
1753 filterType != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
1754 LOG(ERROR) << "Unsupported filter tensor type for operation "
1755 << getOperationName(opType);
1756 return ANEURALNETWORKS_BAD_DATA;
1757 }
1758
1759 if (filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL &&
1760 operands[inputIndexes[1]].extraParams.channelQuant().channelDim != 0) {
1761 LOG(ERROR) << "Unsupported filter tensor channel dimension for operation "
1762 << getOperationName(opType);
1763 return ANEURALNETWORKS_BAD_DATA;
1764 }
1765
1766 inExpectedTypes = {
1767 inputType, filterType, OperandType::TENSOR_INT32,
1768 OperandType::INT32, OperandType::INT32, OperandType::INT32,
1769 OperandType::INT32, OperandType::INT32};
1770 outExpectedTypes = {inputType};
1771 } else {
1772 LOG(ERROR) << "Unsupported input tensor type for operation "
1773 << getOperationName(opType);
1774 return ANEURALNETWORKS_BAD_DATA;
1775 }
1776
1777 if (inputCount == 12) {
1778 std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
1779 inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
1780 explicitScalarTypes.end());
1781 }
1782 inExpectedTypes.push_back(OperandType::BOOL);
1783 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1784 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1785 } else {
1786 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1787 }
1788 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1789 inExpectedTypes, outputCount, outputIndexes,
1790 outExpectedTypes);
1791 }
1792 case ANEURALNETWORKS_TILE: {
1793 if (inputCount != 2 || outputCount != 1) {
1794 logInvalidInOutNumber(2, 1);
1795 return ANEURALNETWORKS_BAD_DATA;
1796 }
1797 auto inputType = operands[inputIndexes[0]].type;
1798 std::vector<OperandType> inExpectedTypes;
1799 std::vector<OperandType> outExpectedTypes;
1800 if (inputType == OperandType::TENSOR_FLOAT16 ||
1801 inputType == OperandType::TENSOR_FLOAT32 ||
1802 inputType == OperandType::TENSOR_INT32 ||
1803 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1804 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1805 inExpectedTypes = {inputType, OperandType::TENSOR_INT32};
1806 outExpectedTypes = {inputType};
1807 } else {
1808 LOG(ERROR) << "Unsupported input tensor type for operation "
1809 << getOperationName(opType);
1810 return ANEURALNETWORKS_BAD_DATA;
1811 }
1812 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1813 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1814 } else {
1815 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1816 }
1817 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1818 inExpectedTypes, outputCount, outputIndexes,
1819 outExpectedTypes);
1820 }
1821 case ANEURALNETWORKS_POW: {
1822 if (inputCount != 2 || outputCount != 1) {
1823 logInvalidInOutNumber(2, 1);
1824 return ANEURALNETWORKS_BAD_DATA;
1825 }
1826 auto inputType = operands[inputIndexes[0]].type;
1827 std::vector<OperandType> inExpectedTypes;
1828 std::vector<OperandType> outExpectedTypes;
1829 if (inputType == OperandType::TENSOR_FLOAT16 ||
1830 inputType == OperandType::TENSOR_FLOAT32) {
1831 inExpectedTypes = {inputType, inputType};
1832 outExpectedTypes = {inputType};
1833 } else {
1834 LOG(ERROR) << "Unsupported input tensor type for operation "
1835 << getOperationName(opType);
1836 return ANEURALNETWORKS_BAD_DATA;
1837 }
1838 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1839 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1840 } else {
1841 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
1842 }
1843 return validateOperationOperandTypes(operands, inputCount, inputIndexes,
1844 inExpectedTypes, outputCount, outputIndexes,
1845 outExpectedTypes);
1846 }
1847 case ANEURALNETWORKS_IF: {
1848 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1849 return validateIfOperation(inputCount, inputIndexes, outputCount, outputIndexes,
1850 operands, helper)
1851 ? ANEURALNETWORKS_NO_ERROR
1852 : ANEURALNETWORKS_BAD_DATA;
1853 }
1854 case ANEURALNETWORKS_WHILE: {
1855 NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
1856 return validateWhileOperation(inputCount, inputIndexes, outputCount, outputIndexes,
1857 operands, helper)
1858 ? ANEURALNETWORKS_NO_ERROR
1859 : ANEURALNETWORKS_BAD_DATA;
1860 }
1861 default: {
1862 const OperationRegistration* operationRegistration =
1863 BuiltinOperationResolver::get()->findOperation(
1864 static_cast<OperationType>(opType));
1865 if (operationRegistration == nullptr) {
1866 if (0 <= opType && opType < kNumberOfOperationTypes) {
1867 LOG(ERROR) << getOperationName(opType) << " not registered";
1868 } else {
1869 LOG(ERROR) << "Operation type " << opType << " out of the range [0, "
1870 << kNumberOfOperationTypes << ")";
1871 }
1872 return ANEURALNETWORKS_UNEXPECTED_NULL;
1873 }
1874 if (operationRegistration->validate == nullptr) {
1875 LOG(ERROR) << "Incomplete operation registration: " << getOperationName(opType);
1876 return ANEURALNETWORKS_UNEXPECTED_NULL;
1877 }
1878 OperationValidationContext context(operationRegistration->name, inputCount,
1879 inputIndexes, outputCount, outputIndexes,
1880 operands.data(), halVersion);
1881 if (!operationRegistration->validate(&context)) {
1882 LOG(ERROR) << "Validation failed for operation " << getOperationName(opType);
1883 return ANEURALNETWORKS_BAD_DATA;
1884 }
1885 return ANEURALNETWORKS_NO_ERROR;
1886 }
1887 }
1888 }
1889
convertResultCodeToErrorStatus(int resultCode)1890 ErrorStatus convertResultCodeToErrorStatus(int resultCode) {
1891 switch (resultCode) {
1892 case ANEURALNETWORKS_NO_ERROR:
1893 return ErrorStatus::NONE;
1894
1895 case ANEURALNETWORKS_BAD_DATA:
1896 case ANEURALNETWORKS_UNEXPECTED_NULL:
1897 return ErrorStatus::INVALID_ARGUMENT;
1898
1899 case ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE:
1900 return ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
1901
1902 case ANEURALNETWORKS_UNAVAILABLE_DEVICE:
1903 return ErrorStatus::DEVICE_UNAVAILABLE;
1904
1905 case ANEURALNETWORKS_BAD_STATE:
1906 case ANEURALNETWORKS_INCOMPLETE:
1907 case ANEURALNETWORKS_OP_FAILED:
1908 case ANEURALNETWORKS_OUT_OF_MEMORY:
1909 case ANEURALNETWORKS_UNMAPPABLE:
1910 case ANEURALNETWORKS_DEAD_OBJECT:
1911 return ErrorStatus::GENERAL_FAILURE;
1912
1913 case ANEURALNETWORKS_MISSED_DEADLINE_TRANSIENT:
1914 return ErrorStatus::MISSED_DEADLINE_TRANSIENT;
1915 case ANEURALNETWORKS_MISSED_DEADLINE_PERSISTENT:
1916 return ErrorStatus::MISSED_DEADLINE_PERSISTENT;
1917 case ANEURALNETWORKS_RESOURCE_EXHAUSTED_TRANSIENT:
1918 return ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT;
1919 case ANEURALNETWORKS_RESOURCE_EXHAUSTED_PERSISTENT:
1920 return ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT;
1921 }
1922 LOG(ERROR) << "Unknown result code " << resultCode << " mapped to ErrorStatus::GENERAL_FAILURE";
1923 return ErrorStatus::GENERAL_FAILURE;
1924 }
1925
convertErrorStatusToResultCode(ErrorStatus status)1926 int convertErrorStatusToResultCode(ErrorStatus status) {
1927 switch (status) {
1928 case ErrorStatus::NONE:
1929 return ANEURALNETWORKS_NO_ERROR;
1930 case ErrorStatus::DEVICE_UNAVAILABLE:
1931 return ANEURALNETWORKS_UNAVAILABLE_DEVICE;
1932 case ErrorStatus::GENERAL_FAILURE:
1933 return ANEURALNETWORKS_OP_FAILED;
1934 case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
1935 return ANEURALNETWORKS_OUTPUT_INSUFFICIENT_SIZE;
1936 case ErrorStatus::INVALID_ARGUMENT:
1937 return ANEURALNETWORKS_BAD_DATA;
1938 case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
1939 return ANEURALNETWORKS_MISSED_DEADLINE_TRANSIENT;
1940 case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
1941 return ANEURALNETWORKS_MISSED_DEADLINE_PERSISTENT;
1942 case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
1943 return ANEURALNETWORKS_RESOURCE_EXHAUSTED_TRANSIENT;
1944 case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
1945 return ANEURALNETWORKS_RESOURCE_EXHAUSTED_PERSISTENT;
1946 }
1947 LOG(ERROR) << "Unknown ErrorStatus " << toString(status)
1948 << " mapped to ANEURALNETWORKS_OP_FAILED";
1949 return ANEURALNETWORKS_OP_FAILED;
1950 }
1951
getExecutionResult(ErrorStatus status,std::vector<OutputShape> outputShapes,Timing timing)1952 std::tuple<int, std::vector<OutputShape>, Timing> getExecutionResult(
1953 ErrorStatus status, std::vector<OutputShape> outputShapes, Timing timing) {
1954 constexpr Timing kNoTiming = {std::numeric_limits<uint64_t>::max(),
1955 std::numeric_limits<uint64_t>::max()};
1956 const int n = convertErrorStatusToResultCode(status);
1957 if (status != ErrorStatus::NONE && status != ErrorStatus::OUTPUT_INSUFFICIENT_SIZE &&
1958 !outputShapes.empty()) {
1959 LOG(ERROR) << "The driver returned OutputShapes when it shouldn't.";
1960 outputShapes.clear();
1961 }
1962 if (status != ErrorStatus::NONE && timing != kNoTiming) {
1963 LOG(ERROR) << "The driver returned Timing when it shouldn't.";
1964 timing = kNoTiming;
1965 }
1966 return {n, std::move(outputShapes), timing};
1967 }
1968
combineDimensions(const std::vector<uint32_t> & lhs,const std::vector<uint32_t> & rhs)1969 std::optional<std::vector<uint32_t>> combineDimensions(const std::vector<uint32_t>& lhs,
1970 const std::vector<uint32_t>& rhs) {
1971 if (rhs.empty()) return lhs;
1972 if (lhs.empty()) return rhs;
1973 if (lhs.size() != rhs.size()) {
1974 LOG(ERROR) << "Incompatible ranks: " << toString(lhs) << " and " << toString(rhs);
1975 return std::nullopt;
1976 }
1977 std::vector<uint32_t> combined = lhs;
1978 for (uint32_t i = 0; i < lhs.size(); i++) {
1979 if (lhs[i] == 0) {
1980 combined[i] = rhs[i];
1981 } else if (rhs[i] != 0 && lhs[i] != rhs[i]) {
1982 LOG(ERROR) << "Incompatible dimensions: " << toString(lhs) << " and " << toString(rhs);
1983 return std::nullopt;
1984 }
1985 }
1986 return combined;
1987 }
1988
1989 // Capabilities::operandPerformance utilities.
1990 // The field Capabilities::operandPerformance is a vector sorted by the field
1991 // Capabilities::OperandPerformance::type.
1992
1993 template <HalVersion version>
nonExtensionOperandPerformance(PerformanceInfo perf)1994 hidl_vec<VersionedOperandPerformance<version>> nonExtensionOperandPerformance(
1995 PerformanceInfo perf) {
1996 using OpPerf = VersionedOperandPerformance<version>;
1997
1998 // Note: range presents enumerators in declaration order, not in numerical order.
1999 static constexpr hidl_enum_range<VersionedOperandType<version>> kOperandTypeRange;
2000
2001 std::vector<OpPerf> ret;
2002 ret.reserve(kOperandTypeRange.end() - kOperandTypeRange.begin());
2003 for (VersionedOperandType<version> type : kOperandTypeRange) {
2004 if (static_cast<OperandType>(type) != OperandType::SUBGRAPH) {
2005 ret.push_back(OpPerf{type, perf});
2006 }
2007 }
2008 std::sort(ret.begin(), ret.end(),
2009 [](const OpPerf& a, const OpPerf& b) { return a.type < b.type; });
2010
2011 return ret;
2012 }
2013
// Explicit instantiations of nonExtensionOperandPerformance for HAL V1_2 and V1_3. The template
// body is defined only in this file, so these emit the two specializations here for use elsewhere.
template hal::hidl_vec<V1_2::Capabilities::OperandPerformance>
nonExtensionOperandPerformance<HalVersion::V1_2>(PerformanceInfo perf);
template hal::hidl_vec<V1_3::Capabilities::OperandPerformance>
nonExtensionOperandPerformance<HalVersion::V1_3>(PerformanceInfo perf);
2018
2019 template <HalVersion version>
update(hal::hidl_vec<VersionedOperandPerformance<version>> * operandPerformance,VersionedOperandType<version> type,hal::PerformanceInfo perf)2020 void update(hal::hidl_vec<VersionedOperandPerformance<version>>* operandPerformance,
2021 VersionedOperandType<version> type, hal::PerformanceInfo perf) {
2022 CHECK(operandPerformance != nullptr);
2023 const auto it =
2024 std::lower_bound(operandPerformance->begin(), operandPerformance->end(), type,
2025 [](const VersionedOperandPerformance<version>& perf,
2026 VersionedOperandType<version> type) { return perf.type < type; });
2027 CHECK(it != operandPerformance->end())
2028 << toString(type) << " not in " << toString(*operandPerformance);
2029 it->info = perf;
2030 }
2031
// Version-specific entry points that forward to the update<HalVersion> template, which
// CHECK-fails when `type` has no entry in *operandPerformance.
void update(hidl_vec<V1_2::Capabilities::OperandPerformance>* operandPerformance,
            V1_2::OperandType type, PerformanceInfo perf) {
    update<HalVersion::V1_2>(operandPerformance, type, perf);
}
void update(hidl_vec<V1_3::Capabilities::OperandPerformance>* operandPerformance,
            V1_3::OperandType type, PerformanceInfo perf) {
    update<HalVersion::V1_3>(operandPerformance, type, perf);
}
2040
2041 template <HalVersion version>
lookup(const hidl_vec<VersionedOperandPerformance<version>> & operandPerformance,VersionedOperandType<version> type)2042 PerformanceInfo lookup(const hidl_vec<VersionedOperandPerformance<version>>& operandPerformance,
2043 VersionedOperandType<version> type) {
2044 const auto it = std::lower_bound(operandPerformance.begin(), operandPerformance.end(), type,
2045 [](const VersionedOperandPerformance<version>& perf,
2046 VersionedOperandType<version> type) {
2047 return static_cast<OperandType>(perf.type) <
2048 static_cast<OperandType>(type);
2049 });
2050 if (it == operandPerformance.end()) {
2051 LOG(WARNING) << "No PerformanceInfo for " << toString(type);
2052 return kNoPerformanceInfo;
2053 } else {
2054 return it->info;
2055 }
2056 }
2057
// Version-specific entry points that forward to the lookup<HalVersion> template.
PerformanceInfo lookup(const hidl_vec<V1_2::Capabilities::OperandPerformance>& operandPerformance,
                       V1_2::OperandType type) {
    return lookup<HalVersion::V1_2>(operandPerformance, type);
}
// Querying SUBGRAPH here is a CHECK failure; as the message says, its performance is found in
// Capabilities::ifPerformance / Capabilities::whilePerformance instead.
PerformanceInfo lookup(const hidl_vec<V1_3::Capabilities::OperandPerformance>& operandPerformance,
                       V1_3::OperandType type) {
    CHECK(type != V1_3::OperandType::SUBGRAPH)
            << "Use Capabilities::ifPerformance or Capabilities::whilePerformance";
    return lookup<HalVersion::V1_3>(operandPerformance, type);
}
2068
2069 // Versioning
2070
2071 // In Android P, most data types are treated as having the same performance as TENSOR_QUANT8_ASYMM.
2072 // This array must be in sorted order.
2073 static const OperandType kQuantized8PerformanceConsistentWithP[] = {
2074 OperandType::INT32, OperandType::UINT32, OperandType::TENSOR_INT32, OperandType::OEM,
2075 OperandType::TENSOR_OEM_BYTE};
2076
isQuantized8PerformanceConsistentWithP(const V1_2::Capabilities & capabilities)2077 static bool isQuantized8PerformanceConsistentWithP(const V1_2::Capabilities& capabilities) {
2078 const PerformanceInfo quantized8Performance =
2079 lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_QUANT8_ASYMM);
2080 return std::all_of(std::begin(kQuantized8PerformanceConsistentWithP),
2081 std::end(kQuantized8PerformanceConsistentWithP),
2082 [quantized8Performance, &capabilities](OperandType type) {
2083 return quantized8Performance ==
2084 lookup(capabilities.operandPerformance,
2085 static_cast<V1_2::OperandType>(type));
2086 });
2087 }
2088
isQuantized8PerformanceConsistentWithP(const V1_3::Capabilities & capabilities)2089 static bool isQuantized8PerformanceConsistentWithP(const V1_3::Capabilities& capabilities) {
2090 const PerformanceInfo quantized8Performance =
2091 lookup(capabilities.operandPerformance, OperandType::TENSOR_QUANT8_ASYMM);
2092 return std::all_of(std::begin(kQuantized8PerformanceConsistentWithP),
2093 std::end(kQuantized8PerformanceConsistentWithP),
2094 [quantized8Performance, &capabilities](OperandType type) {
2095 return quantized8Performance ==
2096 lookup(capabilities.operandPerformance, type);
2097 });
2098 }
2099
makeQuantized8PerformanceConsistentWithP(PerformanceInfo quantized8Performance)2100 static hidl_vec<V1_2::Capabilities::OperandPerformance> makeQuantized8PerformanceConsistentWithP(
2101 PerformanceInfo quantized8Performance) {
2102 hidl_vec<V1_2::Capabilities::OperandPerformance> ret(
2103 std::size(kQuantized8PerformanceConsistentWithP));
2104 std::transform(
2105 std::begin(kQuantized8PerformanceConsistentWithP),
2106 std::end(kQuantized8PerformanceConsistentWithP), ret.begin(),
2107 [quantized8Performance](OperandType type) -> V1_2::Capabilities::OperandPerformance {
2108 return {static_cast<V1_2::OperandType>(type), quantized8Performance};
2109 });
2110 return ret;
2111 }
2112
compliantWithV1_0(const V1_0::Capabilities &)2113 bool compliantWithV1_0(const V1_0::Capabilities&) {
2114 return true;
2115 }
2116
compliantWithV1_0(const V1_1::Capabilities & capabilities)2117 bool compliantWithV1_0(const V1_1::Capabilities& capabilities) {
2118 return capabilities.relaxedFloat32toFloat16Performance == capabilities.float32Performance;
2119 }
2120
compliantWithV1_0(const V1_2::Capabilities & capabilities)2121 bool compliantWithV1_0(const V1_2::Capabilities& capabilities) {
2122 const PerformanceInfo perfTensorFloat32 =
2123 lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT32);
2124 const PerformanceInfo perfFloat32 =
2125 lookup(capabilities.operandPerformance, V1_2::OperandType::FLOAT32);
2126 if (perfTensorFloat32 != perfFloat32 ||
2127 perfTensorFloat32 != capabilities.relaxedFloat32toFloat16PerformanceTensor ||
2128 perfFloat32 != capabilities.relaxedFloat32toFloat16PerformanceScalar) {
2129 return false;
2130 }
2131
2132 return isQuantized8PerformanceConsistentWithP(capabilities);
2133 }
2134
compliantWithV1_0(const V1_3::Capabilities & capabilities)2135 bool compliantWithV1_0(const V1_3::Capabilities& capabilities) {
2136 const PerformanceInfo perfTensorFloat32 =
2137 lookup(capabilities.operandPerformance, OperandType::TENSOR_FLOAT32);
2138 const PerformanceInfo perfFloat32 =
2139 lookup(capabilities.operandPerformance, OperandType::FLOAT32);
2140 if (perfTensorFloat32 != perfFloat32 ||
2141 perfTensorFloat32 != capabilities.relaxedFloat32toFloat16PerformanceTensor ||
2142 perfFloat32 != capabilities.relaxedFloat32toFloat16PerformanceScalar) {
2143 return false;
2144 }
2145
2146 return isQuantized8PerformanceConsistentWithP(capabilities);
2147 }
2148
compliantWithV1_1(const V1_0::Capabilities &)2149 bool compliantWithV1_1(const V1_0::Capabilities&) {
2150 return true;
2151 }
2152
compliantWithV1_1(const V1_1::Capabilities &)2153 bool compliantWithV1_1(const V1_1::Capabilities&) {
2154 return true;
2155 }
2156
compliantWithV1_1(const V1_2::Capabilities & capabilities)2157 bool compliantWithV1_1(const V1_2::Capabilities& capabilities) {
2158 if ((capabilities.relaxedFloat32toFloat16PerformanceTensor !=
2159 capabilities.relaxedFloat32toFloat16PerformanceScalar) ||
2160 (lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT32) !=
2161 lookup(capabilities.operandPerformance, V1_2::OperandType::FLOAT32))) {
2162 return false;
2163 }
2164
2165 return isQuantized8PerformanceConsistentWithP(capabilities);
2166 }
2167
compliantWithV1_1(const V1_3::Capabilities & capabilities)2168 bool compliantWithV1_1(const V1_3::Capabilities& capabilities) {
2169 if ((capabilities.relaxedFloat32toFloat16PerformanceTensor !=
2170 capabilities.relaxedFloat32toFloat16PerformanceScalar) ||
2171 (lookup(capabilities.operandPerformance, OperandType::TENSOR_FLOAT32) !=
2172 lookup(capabilities.operandPerformance, OperandType::FLOAT32))) {
2173 return false;
2174 }
2175
2176 return isQuantized8PerformanceConsistentWithP(capabilities);
2177 }
2178
compliantWithV1_2(const V1_0::Capabilities &)2179 bool compliantWithV1_2(const V1_0::Capabilities&) {
2180 return true;
2181 }
2182
compliantWithV1_2(const V1_1::Capabilities &)2183 bool compliantWithV1_2(const V1_1::Capabilities&) {
2184 return true;
2185 }
2186
compliantWithV1_2(const V1_2::Capabilities &)2187 bool compliantWithV1_2(const V1_2::Capabilities&) {
2188 return true;
2189 }
2190
compliantWithV1_2(const V1_3::Capabilities &)2191 bool compliantWithV1_2(const V1_3::Capabilities&) {
2192 return true;
2193 }
2194
compliantWithV1_3(const V1_0::Capabilities &)2195 bool compliantWithV1_3(const V1_0::Capabilities&) {
2196 return true;
2197 }
2198
compliantWithV1_3(const V1_1::Capabilities &)2199 bool compliantWithV1_3(const V1_1::Capabilities&) {
2200 return true;
2201 }
2202
compliantWithV1_3(const V1_2::Capabilities &)2203 bool compliantWithV1_3(const V1_2::Capabilities&) {
2204 return true;
2205 }
2206
compliantWithV1_3(const V1_3::Capabilities &)2207 bool compliantWithV1_3(const V1_3::Capabilities&) {
2208 return true;
2209 }
2210
convertToV1_0(V1_0::ErrorStatus status)2211 V1_0::ErrorStatus convertToV1_0(V1_0::ErrorStatus status) {
2212 return status;
2213 }
2214
convertToV1_0(V1_3::ErrorStatus status)2215 V1_0::ErrorStatus convertToV1_0(V1_3::ErrorStatus status) {
2216 switch (status) {
2217 case V1_3::ErrorStatus::NONE:
2218 return V1_0::ErrorStatus::NONE;
2219 case V1_3::ErrorStatus::DEVICE_UNAVAILABLE:
2220 return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
2221 case V1_3::ErrorStatus::GENERAL_FAILURE:
2222 return V1_0::ErrorStatus::GENERAL_FAILURE;
2223 case V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
2224 return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
2225 case V1_3::ErrorStatus::INVALID_ARGUMENT:
2226 return V1_0::ErrorStatus::INVALID_ARGUMENT;
2227 case V1_3::ErrorStatus::MISSED_DEADLINE_TRANSIENT:
2228 return V1_0::ErrorStatus::GENERAL_FAILURE;
2229 case V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT:
2230 return V1_0::ErrorStatus::GENERAL_FAILURE;
2231 case V1_3::ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
2232 return V1_0::ErrorStatus::GENERAL_FAILURE;
2233 case V1_3::ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
2234 return V1_0::ErrorStatus::GENERAL_FAILURE;
2235 }
2236 LOG(ERROR) << "Unknown ErrorStatus: " << toString(status) << " mapped to GENERAL_FAILURE";
2237 return V1_0::ErrorStatus::GENERAL_FAILURE;
2238 }
2239
convertToV1_3(V1_0::ErrorStatus status)2240 V1_3::ErrorStatus convertToV1_3(V1_0::ErrorStatus status) {
2241 return static_cast<V1_3::ErrorStatus>(status);
2242 }
2243
convertToV1_3(V1_3::ErrorStatus status)2244 V1_3::ErrorStatus convertToV1_3(V1_3::ErrorStatus status) {
2245 return status;
2246 }
2247
uncheckedConvertToV1_0(V1_1::OperationType type)2248 static V1_0::OperationType uncheckedConvertToV1_0(V1_1::OperationType type) {
2249 return static_cast<V1_0::OperationType>(type);
2250 }
2251
uncheckedConvertToV1_0(V1_2::OperationType type)2252 static V1_0::OperationType uncheckedConvertToV1_0(V1_2::OperationType type) {
2253 return static_cast<V1_0::OperationType>(type);
2254 }
2255
uncheckedConvertToV1_0(V1_3::OperationType type)2256 V1_0::OperationType uncheckedConvertToV1_0(V1_3::OperationType type) {
2257 return static_cast<V1_0::OperationType>(type);
2258 }
2259
convertToV1_1(V1_0::OperationType type)2260 static V1_1::OperationType convertToV1_1(V1_0::OperationType type) {
2261 return static_cast<V1_1::OperationType>(type);
2262 }
2263
uncheckedConvertToV1_1(V1_2::OperationType type)2264 static V1_1::OperationType uncheckedConvertToV1_1(V1_2::OperationType type) {
2265 return static_cast<V1_1::OperationType>(type);
2266 }
2267
uncheckedConvertToV1_1(V1_3::OperationType type)2268 V1_1::OperationType uncheckedConvertToV1_1(V1_3::OperationType type) {
2269 return static_cast<V1_1::OperationType>(type);
2270 }
2271
convertToV1_2(V1_0::OperationType type)2272 static V1_2::OperationType convertToV1_2(V1_0::OperationType type) {
2273 return static_cast<V1_2::OperationType>(type);
2274 }
2275
convertToV1_2(V1_1::OperationType type)2276 static V1_2::OperationType convertToV1_2(V1_1::OperationType type) {
2277 return static_cast<V1_2::OperationType>(type);
2278 }
2279
uncheckedConvertToV1_2(V1_3::OperationType type)2280 V1_2::OperationType uncheckedConvertToV1_2(V1_3::OperationType type) {
2281 return static_cast<V1_2::OperationType>(type);
2282 }
2283
convertToV1_3(V1_0::OperationType type)2284 static V1_3::OperationType convertToV1_3(V1_0::OperationType type) {
2285 return static_cast<V1_3::OperationType>(type);
2286 }
2287
convertToV1_3(V1_1::OperationType type)2288 static V1_3::OperationType convertToV1_3(V1_1::OperationType type) {
2289 return static_cast<V1_3::OperationType>(type);
2290 }
2291
convertToV1_3(V1_2::OperationType type)2292 static V1_3::OperationType convertToV1_3(V1_2::OperationType type) {
2293 return static_cast<V1_3::OperationType>(type);
2294 }
2295
// Identity conversion.
V1_0::Capabilities convertToV1_0(const V1_0::Capabilities& capabilities) {
    return capabilities;
}

// Narrows V1_1 capabilities to V1_0 by dropping relaxedFloat32toFloat16Performance.
// If the input is not V1_0-compliant, an error is logged and the lossy result is
// still returned.
V1_0::Capabilities convertToV1_0(const V1_1::Capabilities& capabilities) {
    const bool compliant = compliantWithV1_0(capabilities);
    if (!compliant) {
        LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
                   << " from V1_1::Capabilities to V1_0::Capabilities";
    }
    V1_0::Capabilities narrowed = {.float32Performance = capabilities.float32Performance,
                                   .quantized8Performance = capabilities.quantized8Performance};
    return narrowed;
}
2308
convertToV1_0(const V1_2::Capabilities & capabilities)2309 V1_0::Capabilities convertToV1_0(const V1_2::Capabilities& capabilities) {
2310 if (!compliantWithV1_0(capabilities)) {
2311 LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
2312 << " from V1_2::Capabilities to V1_0::Capabilities";
2313 }
2314 return {.float32Performance =
2315 lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT32),
2316 .quantized8Performance = lookup(capabilities.operandPerformance,
2317 V1_2::OperandType::TENSOR_QUANT8_ASYMM)};
2318 }
2319
convertToV1_0(const V1_3::Capabilities & capabilities)2320 V1_0::Capabilities convertToV1_0(const V1_3::Capabilities& capabilities) {
2321 if (!compliantWithV1_0(capabilities)) {
2322 LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
2323 << " from V1_3::Capabilities to V1_0::Capabilities";
2324 }
2325 return {.float32Performance =
2326 lookup(capabilities.operandPerformance, OperandType::TENSOR_FLOAT32),
2327 .quantized8Performance =
2328 lookup(capabilities.operandPerformance, OperandType::TENSOR_QUANT8_ASYMM)};
2329 }
2330
convertToV1_1(const V1_0::Capabilities & capabilities)2331 V1_1::Capabilities convertToV1_1(const V1_0::Capabilities& capabilities) {
2332 return {.float32Performance = capabilities.float32Performance,
2333 .quantized8Performance = capabilities.quantized8Performance,
2334 .relaxedFloat32toFloat16Performance = capabilities.float32Performance};
2335 }
2336
convertToV1_1(const V1_1::Capabilities & capabilities)2337 V1_1::Capabilities convertToV1_1(const V1_1::Capabilities& capabilities) {
2338 return capabilities;
2339 }
2340
convertToV1_1(const V1_2::Capabilities & capabilities)2341 V1_1::Capabilities convertToV1_1(const V1_2::Capabilities& capabilities) {
2342 if (!compliantWithV1_1(capabilities)) {
2343 LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
2344 << " from V1_2::Capabilities to V1_1::Capabilities";
2345 }
2346 return {.float32Performance =
2347 lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_FLOAT32),
2348 .quantized8Performance =
2349 lookup(capabilities.operandPerformance, V1_2::OperandType::TENSOR_QUANT8_ASYMM),
2350 .relaxedFloat32toFloat16Performance =
2351 capabilities.relaxedFloat32toFloat16PerformanceTensor};
2352 }
2353
convertToV1_1(const V1_3::Capabilities & capabilities)2354 V1_1::Capabilities convertToV1_1(const V1_3::Capabilities& capabilities) {
2355 if (!compliantWithV1_1(capabilities)) {
2356 LOG(ERROR) << "Upcasting non-compliant capabilities " << toString(capabilities)
2357 << " from V1_3::Capabilities to V1_1::Capabilities";
2358 }
2359 return {.float32Performance =
2360 lookup(capabilities.operandPerformance, OperandType::TENSOR_FLOAT32),
2361 .quantized8Performance =
2362 lookup(capabilities.operandPerformance, OperandType::TENSOR_QUANT8_ASYMM),
2363 .relaxedFloat32toFloat16Performance =
2364 capabilities.relaxedFloat32toFloat16PerformanceTensor};
2365 }
2366
convertToV1_2(const V1_0::Capabilities & capabilities)2367 V1_2::Capabilities convertToV1_2(const V1_0::Capabilities& capabilities) {
2368 V1_2::Capabilities ret = {
2369 .relaxedFloat32toFloat16PerformanceScalar = capabilities.float32Performance,
2370 .relaxedFloat32toFloat16PerformanceTensor = capabilities.float32Performance,
2371 .operandPerformance =
2372 makeQuantized8PerformanceConsistentWithP(capabilities.quantized8Performance)};
2373 auto& opPerf = ret.operandPerformance;
2374 opPerf.resize(opPerf.size() + 2);
2375 opPerf[opPerf.size() - 2] = {V1_2::OperandType::TENSOR_FLOAT32,
2376 capabilities.float32Performance};
2377 opPerf[opPerf.size() - 1] = {V1_2::OperandType::FLOAT32, capabilities.float32Performance};
2378 using OperandPerformance = V1_2::Capabilities::OperandPerformance;
2379 std::sort(opPerf.begin(), opPerf.end(),
2380 [](const OperandPerformance& a, const OperandPerformance& b) {
2381 return a.type < b.type;
2382 });
2383 return ret;
2384 }
2385
convertToV1_2(const V1_1::Capabilities & capabilities)2386 V1_2::Capabilities convertToV1_2(const V1_1::Capabilities& capabilities) {
2387 V1_2::Capabilities ret = {.relaxedFloat32toFloat16PerformanceScalar =
2388 capabilities.relaxedFloat32toFloat16Performance,
2389 .relaxedFloat32toFloat16PerformanceTensor =
2390 capabilities.relaxedFloat32toFloat16Performance,
2391 .operandPerformance = makeQuantized8PerformanceConsistentWithP(
2392 capabilities.quantized8Performance)};
2393 auto& opPerf = ret.operandPerformance;
2394 opPerf.resize(opPerf.size() + 2);
2395 opPerf[opPerf.size() - 2] = {V1_2::OperandType::TENSOR_FLOAT32,
2396 capabilities.float32Performance};
2397 opPerf[opPerf.size() - 1] = {V1_2::OperandType::FLOAT32, capabilities.float32Performance};
2398 using OperandPerformance = V1_2::Capabilities::OperandPerformance;
2399 std::sort(opPerf.begin(), opPerf.end(),
2400 [](const OperandPerformance& a, const OperandPerformance& b) {
2401 return a.type < b.type;
2402 });
2403 return ret;
2404 }
2405
// Identity conversion.
V1_2::Capabilities convertToV1_2(const V1_2::Capabilities& capabilities) {
    const V1_2::Capabilities& unchanged = capabilities;
    return unchanged;
}
2409
convertToV1_2(const V1_3::Capabilities & capabilities)2410 V1_2::Capabilities convertToV1_2(const V1_3::Capabilities& capabilities) {
2411 V1_2::Capabilities ret = {
2412 .relaxedFloat32toFloat16PerformanceScalar =
2413 capabilities.relaxedFloat32toFloat16PerformanceScalar,
2414 .relaxedFloat32toFloat16PerformanceTensor =
2415 capabilities.relaxedFloat32toFloat16PerformanceTensor,
2416 };
2417 const auto& inputOpPerf = capabilities.operandPerformance;
2418 hidl_vec<V1_3::Capabilities::OperandPerformance> opPerfSupported;
2419 opPerfSupported.resize(inputOpPerf.size());
2420 auto last =
2421 std::copy_if(inputOpPerf.begin(), inputOpPerf.end(), opPerfSupported.begin(),
2422 [](V1_3::Capabilities::OperandPerformance opPerf) {
2423 return validOperandType(static_cast<V1_2::OperandType>(opPerf.type));
2424 });
2425 opPerfSupported.resize(std::distance(opPerfSupported.begin(), last));
2426
2427 auto& convertedOpPerf = ret.operandPerformance;
2428 convertedOpPerf.resize(opPerfSupported.size());
2429 std::transform(opPerfSupported.begin(), opPerfSupported.end(), convertedOpPerf.begin(),
2430 [](V1_3::Capabilities::OperandPerformance opPerf) {
2431 return V1_2::Capabilities::OperandPerformance{
2432 static_cast<V1_2::OperandType>(opPerf.type), opPerf.info};
2433 });
2434 return ret;
2435 }
2436
convertToV1_3(const V1_0::Capabilities & capabilities)2437 V1_3::Capabilities convertToV1_3(const V1_0::Capabilities& capabilities) {
2438 return convertToV1_3(convertToV1_2(capabilities));
2439 }
2440
convertToV1_3(const V1_1::Capabilities & capabilities)2441 V1_3::Capabilities convertToV1_3(const V1_1::Capabilities& capabilities) {
2442 return convertToV1_3(convertToV1_2(capabilities));
2443 }
2444
convertToV1_3(const V1_2::Capabilities & capabilities)2445 V1_3::Capabilities convertToV1_3(const V1_2::Capabilities& capabilities) {
2446 V1_3::Capabilities ret = {
2447 .relaxedFloat32toFloat16PerformanceScalar =
2448 capabilities.relaxedFloat32toFloat16PerformanceScalar,
2449 .relaxedFloat32toFloat16PerformanceTensor =
2450 capabilities.relaxedFloat32toFloat16PerformanceTensor,
2451 .ifPerformance = kNoPerformanceInfo,
2452 .whilePerformance = kNoPerformanceInfo,
2453 };
2454 auto& opPerf = ret.operandPerformance;
2455 opPerf.resize(capabilities.operandPerformance.size());
2456 std::transform(capabilities.operandPerformance.begin(), capabilities.operandPerformance.end(),
2457 opPerf.begin(), [](V1_2::Capabilities::OperandPerformance opPerf) {
2458 return V1_3::Capabilities::OperandPerformance{
2459 static_cast<V1_3::OperandType>(opPerf.type), opPerf.info};
2460 });
2461 return ret;
2462 }
2463
// Identity conversion.
V1_3::Capabilities convertToV1_3(const V1_3::Capabilities& capabilities) {
    const V1_3::Capabilities& unchanged = capabilities;
    return unchanged;
}
2467
uncheckedConvertToV1_0(const V1_1::Operation & operation)2468 static V1_0::Operation uncheckedConvertToV1_0(const V1_1::Operation& operation) {
2469 return {.type = uncheckedConvertToV1_0(operation.type),
2470 .inputs = operation.inputs,
2471 .outputs = operation.outputs};
2472 }
2473
convertToV1_1(const V1_0::Operation & operation)2474 static V1_1::Operation convertToV1_1(const V1_0::Operation& operation) {
2475 return {.type = convertToV1_1(operation.type),
2476 .inputs = operation.inputs,
2477 .outputs = operation.outputs};
2478 }
2479
uncheckedConvertToV1_0(const hidl_vec<V1_1::Operation> & operations)2480 static hidl_vec<V1_0::Operation> uncheckedConvertToV1_0(
2481 const hidl_vec<V1_1::Operation>& operations) {
2482 hidl_vec<V1_0::Operation> result(operations.size());
2483 std::transform(
2484 operations.begin(), operations.end(), result.begin(),
2485 [](const V1_1::Operation& operation) { return uncheckedConvertToV1_0(operation); });
2486 return result;
2487 }
2488
convertToV1_1(const hidl_vec<V1_0::Operation> & operations)2489 static hidl_vec<V1_1::Operation> convertToV1_1(const hidl_vec<V1_0::Operation>& operations) {
2490 hidl_vec<V1_1::Operation> result(operations.size());
2491 std::transform(operations.begin(), operations.end(), result.begin(),
2492 [](const V1_0::Operation& operation) { return convertToV1_1(operation); });
2493 return result;
2494 }
2495
compliantWithV1_0(const V1_3::Operand & operand)2496 bool compliantWithV1_0(const V1_3::Operand& operand) {
2497 return validOperandType(static_cast<V1_0::OperandType>(operand.type)) &&
2498 (nonExtensionOperandTypeIsScalar(static_cast<int>(operand.type)) ||
2499 operand.dimensions.size() != 0) &&
2500 compliantWithV1_0(operand.lifetime);
2501 }
2502
compliantWithV1_2(const V1_3::Operand & operand)2503 bool compliantWithV1_2(const V1_3::Operand& operand) {
2504 return validOperandType(static_cast<V1_2::OperandType>(operand.type)) &&
2505 compliantWithV1_0(operand.lifetime);
2506 }
2507
compliantWithV1_3(const V1_3::Operand & operand)2508 bool compliantWithV1_3(const V1_3::Operand& operand) {
2509 return true;
2510 }
2511
compliantWith(HalVersion version,const V1_3::Model & model,std::set<uint32_t> * noncompliantOperations)2512 static bool compliantWith(HalVersion version, const V1_3::Model& model,
2513 std::set<uint32_t>* noncompliantOperations) {
2514 // A boolean vector indicating whether each pool is compliant with the target HAL version.
2515 std::vector<bool> isPoolCompliant(model.pools.size(), false);
2516 std::transform(model.pools.begin(), model.pools.end(), isPoolCompliant.begin(),
2517 [version](const hidl_memory& pool) { return validatePool(pool, version); });
2518
2519 // A boolean vector indicating whether each operand is compliant with the target HAL version.
2520 std::vector<bool> isOperandCompliant(model.main.operands.size(), false);
2521 std::transform(model.main.operands.begin(), model.main.operands.end(),
2522 isOperandCompliant.begin(), [&isPoolCompliant, version](const Operand& op) {
2523 bool is_operand_compliant = false;
2524 switch (version) {
2525 case HalVersion::UNKNOWN:
2526 is_operand_compliant = false;
2527 break;
2528 case HalVersion::V1_0:
2529 is_operand_compliant = compliantWithV1_0(op);
2530 break;
2531 case HalVersion::V1_1:
2532 // There is no V1_1::Operand -- both V1_0::Model
2533 // and V1_1::Model use V1_0::Operand.
2534 is_operand_compliant = compliantWithV1_0(op);
2535 break;
2536 case HalVersion::V1_2:
2537 is_operand_compliant = compliantWithV1_2(op);
2538 break;
2539 case HalVersion::V1_3:
2540 is_operand_compliant = compliantWithV1_3(op);
2541 break;
2542 }
2543 return is_operand_compliant &&
2544 !(op.lifetime == OperandLifeTime::CONSTANT_REFERENCE &&
2545 !isPoolCompliant[op.location.poolIndex]);
2546 });
2547
2548 auto allOperandsCompliant = [&isOperandCompliant](const hidl_vec<uint32_t>& indices) {
2549 return std::all_of(
2550 indices.begin(), indices.end(),
2551 [&isOperandCompliant](const uint32_t ind) { return isOperandCompliant[ind]; });
2552 };
2553
2554 auto localValidateOperation = [&model, version, &allOperandsCompliant](const Operation& op) {
2555 if (!allOperandsCompliant(op.inputs) || !allOperandsCompliant(op.outputs)) return false;
2556 int error = validateOperation(
2557 static_cast<int32_t>(op.type), op.inputs.size(),
2558 op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
2559 op.outputs.size() > 0 ? op.outputs.data() : nullptr, model.main.operands, version);
2560 return error == ANEURALNETWORKS_NO_ERROR;
2561 };
2562
2563 if (noncompliantOperations) {
2564 CHECK(noncompliantOperations->empty());
2565 for (uint32_t idx = 0; idx < model.main.operations.size(); ++idx) {
2566 if (!localValidateOperation(model.main.operations[idx])) {
2567 noncompliantOperations->insert(idx);
2568 }
2569 }
2570 return noncompliantOperations->empty();
2571 } else {
2572 return std::all_of(model.main.operations.begin(), model.main.operations.end(),
2573 localValidateOperation);
2574 }
2575 }
2576
compliantWithV1_0(const V1_0::Model & model)2577 bool compliantWithV1_0(const V1_0::Model& model) {
2578 return true;
2579 }
2580
compliantWithV1_0(const V1_1::Model & model)2581 bool compliantWithV1_0(const V1_1::Model& model) {
2582 // In addition to new enumeration values being introduced in V1_1::Model, a
2583 // new flag was introduced to indicate whether or not float32 data can be
2584 // calculated using float16 units. This 'relaxComputationFloat32toFloat16'
2585 // flag is not relevant in whether a V1_1::Model is compliant with a
2586 // V1_0::Model because all 1.0 drivers require strict calculation by default
2587 // in the P NN runtime. Even if fp16 calculations are allowed, they can
2588 // still be computed by a strict fp32 driver.
2589 return std::all_of(
2590 model.operations.begin(), model.operations.end(), [&model](const V1_1::Operation& op) {
2591 int error = validateOperation(static_cast<int32_t>(op.type), op.inputs.size(),
2592 op.inputs.size() > 0 ? op.inputs.data() : nullptr,
2593 op.outputs.size(),
2594 op.outputs.size() > 0 ? op.outputs.data() : nullptr,
2595 convertToV1_3(model.operands), HalVersion::V1_0);
2596 return error == ANEURALNETWORKS_NO_ERROR;
2597 });
2598 }
2599
compliantWithV1_0(const V1_2::Model & model,std::set<uint32_t> * noncompliantOperations)2600 bool compliantWithV1_0(const V1_2::Model& model, std::set<uint32_t>* noncompliantOperations) {
2601 return compliantWith(HalVersion::V1_0, convertToV1_3(model), noncompliantOperations);
2602 }
2603
compliantWithV1_0(const V1_3::Model & model,std::set<uint32_t> * noncompliantOperations)2604 bool compliantWithV1_0(const V1_3::Model& model, std::set<uint32_t>* noncompliantOperations) {
2605 return compliantWith(HalVersion::V1_0, model, noncompliantOperations);
2606 }
2607
compliantWithV1_1(const V1_0::Model &)2608 bool compliantWithV1_1(const V1_0::Model&) {
2609 return true;
2610 }
2611
compliantWithV1_1(const V1_1::Model &)2612 bool compliantWithV1_1(const V1_1::Model&) {
2613 return true;
2614 }
2615
compliantWithV1_1(const V1_2::Model & model,std::set<uint32_t> * noncompliantOperations)2616 bool compliantWithV1_1(const V1_2::Model& model, std::set<uint32_t>* noncompliantOperations) {
2617 return compliantWith(HalVersion::V1_1, convertToV1_3(model), noncompliantOperations);
2618 }
2619
compliantWithV1_1(const V1_3::Model & model,std::set<uint32_t> * noncompliantOperations)2620 bool compliantWithV1_1(const V1_3::Model& model, std::set<uint32_t>* noncompliantOperations) {
2621 return compliantWith(HalVersion::V1_1, model, noncompliantOperations);
2622 }
2623
compliantWithV1_2(const V1_0::Model &)2624 bool compliantWithV1_2(const V1_0::Model&) {
2625 return true;
2626 }
2627
compliantWithV1_2(const V1_1::Model &)2628 bool compliantWithV1_2(const V1_1::Model&) {
2629 return true;
2630 }
2631
compliantWithV1_2(const V1_2::Model &,std::set<uint32_t> * noncompliantOperations)2632 bool compliantWithV1_2(const V1_2::Model&, std::set<uint32_t>* noncompliantOperations) {
2633 return true;
2634 }
2635
compliantWithV1_2(const V1_3::Model & model,std::set<uint32_t> * noncompliantOperations)2636 bool compliantWithV1_2(const V1_3::Model& model, std::set<uint32_t>* noncompliantOperations) {
2637 return compliantWith(HalVersion::V1_2, model, noncompliantOperations);
2638 }
2639
uncheckedConvertToV1_0(const V1_2::Operation & operation)2640 static V1_0::Operation uncheckedConvertToV1_0(const V1_2::Operation& operation) {
2641 return {.type = uncheckedConvertToV1_0(operation.type),
2642 .inputs = operation.inputs,
2643 .outputs = operation.outputs};
2644 }
2645
uncheckedConvertToV1_0(const V1_3::Operation & operation)2646 static V1_0::Operation uncheckedConvertToV1_0(const V1_3::Operation& operation) {
2647 return {.type = uncheckedConvertToV1_0(operation.type),
2648 .inputs = operation.inputs,
2649 .outputs = operation.outputs};
2650 }
2651
uncheckedConvertToV1_1(const V1_2::Operation & operation)2652 static V1_1::Operation uncheckedConvertToV1_1(const V1_2::Operation& operation) {
2653 return {.type = uncheckedConvertToV1_1(operation.type),
2654 .inputs = operation.inputs,
2655 .outputs = operation.outputs};
2656 }
2657
uncheckedConvertToV1_1(const V1_3::Operation & operation)2658 static V1_1::Operation uncheckedConvertToV1_1(const V1_3::Operation& operation) {
2659 return {.type = uncheckedConvertToV1_1(operation.type),
2660 .inputs = operation.inputs,
2661 .outputs = operation.outputs};
2662 }
2663
convertToV1_2(const V1_0::Operation & operation)2664 static V1_2::Operation convertToV1_2(const V1_0::Operation& operation) {
2665 return {.type = convertToV1_2(operation.type),
2666 .inputs = operation.inputs,
2667 .outputs = operation.outputs};
2668 }
2669
convertToV1_2(const V1_1::Operation & operation)2670 static V1_2::Operation convertToV1_2(const V1_1::Operation& operation) {
2671 return {.type = convertToV1_2(operation.type),
2672 .inputs = operation.inputs,
2673 .outputs = operation.outputs};
2674 }
2675
uncheckedConvertToV1_2(const V1_3::Operation & operation)2676 static V1_2::Operation uncheckedConvertToV1_2(const V1_3::Operation& operation) {
2677 return {.type = uncheckedConvertToV1_2(operation.type),
2678 .inputs = operation.inputs,
2679 .outputs = operation.outputs};
2680 }
2681
convertToV1_3(const V1_0::Operation & operation)2682 static V1_3::Operation convertToV1_3(const V1_0::Operation& operation) {
2683 return {.type = convertToV1_3(operation.type),
2684 .inputs = operation.inputs,
2685 .outputs = operation.outputs};
2686 }
2687
convertToV1_3(const V1_1::Operation & operation)2688 static V1_3::Operation convertToV1_3(const V1_1::Operation& operation) {
2689 return {.type = convertToV1_3(operation.type),
2690 .inputs = operation.inputs,
2691 .outputs = operation.outputs};
2692 }
2693
convertToV1_3(const V1_2::Operation & operation)2694 static V1_3::Operation convertToV1_3(const V1_2::Operation& operation) {
2695 return {.type = convertToV1_3(operation.type),
2696 .inputs = operation.inputs,
2697 .outputs = operation.outputs};
2698 }
2699
uncheckedConvertToV1_0(const hidl_vec<V1_3::Operation> & operations)2700 static hidl_vec<V1_0::Operation> uncheckedConvertToV1_0(
2701 const hidl_vec<V1_3::Operation>& operations) {
2702 hidl_vec<V1_0::Operation> result(operations.size());
2703 std::transform(
2704 operations.begin(), operations.end(), result.begin(),
2705 [](const V1_3::Operation& operation) { return uncheckedConvertToV1_0(operation); });
2706 return result;
2707 }
2708
uncheckedConvertToV1_0(const hidl_vec<V1_2::Operation> & operations)2709 static hidl_vec<V1_0::Operation> uncheckedConvertToV1_0(
2710 const hidl_vec<V1_2::Operation>& operations) {
2711 hidl_vec<V1_0::Operation> result(operations.size());
2712 std::transform(
2713 operations.begin(), operations.end(), result.begin(),
2714 [](const V1_2::Operation& operation) { return uncheckedConvertToV1_0(operation); });
2715 return result;
2716 }
2717
uncheckedConvertToV1_2(const hidl_vec<V1_3::Operation> & operations)2718 static hidl_vec<V1_2::Operation> uncheckedConvertToV1_2(
2719 const hidl_vec<V1_3::Operation>& operations) {
2720 hidl_vec<V1_2::Operation> result(operations.size());
2721 std::transform(
2722 operations.begin(), operations.end(), result.begin(),
2723 [](const V1_3::Operation& operation) { return uncheckedConvertToV1_2(operation); });
2724 return result;
2725 }
2726
uncheckedConvertToV1_1(const hidl_vec<V1_2::Operation> & operations)2727 static hidl_vec<V1_1::Operation> uncheckedConvertToV1_1(
2728 const hidl_vec<V1_2::Operation>& operations) {
2729 hidl_vec<V1_1::Operation> result(operations.size());
2730 std::transform(
2731 operations.begin(), operations.end(), result.begin(),
2732 [](const V1_2::Operation& operation) { return uncheckedConvertToV1_1(operation); });
2733 return result;
2734 }
2735
uncheckedConvertToV1_1(const hidl_vec<V1_3::Operation> & operations)2736 static hidl_vec<V1_1::Operation> uncheckedConvertToV1_1(
2737 const hidl_vec<V1_3::Operation>& operations) {
2738 hidl_vec<V1_1::Operation> result(operations.size());
2739 std::transform(
2740 operations.begin(), operations.end(), result.begin(),
2741 [](const V1_3::Operation& operation) { return uncheckedConvertToV1_1(operation); });
2742 return result;
2743 }
2744
convertToV1_2(const hidl_vec<V1_0::Operation> & operations)2745 static hidl_vec<V1_2::Operation> convertToV1_2(const hidl_vec<V1_0::Operation>& operations) {
2746 hidl_vec<V1_2::Operation> result(operations.size());
2747 std::transform(operations.begin(), operations.end(), result.begin(),
2748 [](const V1_0::Operation& operation) { return convertToV1_2(operation); });
2749 return result;
2750 }
2751
// Widens each V1_1 operation to V1_2 via the per-operation convertToV1_2 overload.
static hidl_vec<V1_2::Operation> convertToV1_2(const hidl_vec<V1_1::Operation>& operations) {
    hidl_vec<V1_2::Operation> widened(operations.size());
    auto dst = widened.begin();
    for (const V1_1::Operation& operation : operations) {
        *dst = convertToV1_2(operation);
        ++dst;
    }
    return widened;
}
2758
// Widens each V1_0 operation to V1_3 via the per-operation convertToV1_3 overload.
static hidl_vec<V1_3::Operation> convertToV1_3(const hidl_vec<V1_0::Operation>& operations) {
    hidl_vec<V1_3::Operation> widened(operations.size());
    for (size_t i = 0; i < operations.size(); ++i) {
        widened[i] = convertToV1_3(operations[i]);
    }
    return widened;
}
2765
// Widens each V1_1 operation to V1_3 via the per-operation convertToV1_3 overload.
static hidl_vec<V1_3::Operation> convertToV1_3(const hidl_vec<V1_1::Operation>& operations) {
    hidl_vec<V1_3::Operation> widened(operations.size());
    auto dst = widened.begin();
    for (const V1_1::Operation& operation : operations) {
        *dst = convertToV1_3(operation);
        ++dst;
    }
    return widened;
}
2772
// Widens each V1_2 operation to V1_3 via the per-operation convertToV1_3 overload.
static hidl_vec<V1_3::Operation> convertToV1_3(const hidl_vec<V1_2::Operation>& operations) {
    const size_t count = operations.size();
    hidl_vec<V1_3::Operation> widened(count);
    for (size_t idx = 0; idx < count; ++idx) {
        widened[idx] = convertToV1_3(operations[idx]);
    }
    return widened;
}
2779
compliantWithV1_0(const V1_2::OperandType & operandType)2780 static bool compliantWithV1_0(const V1_2::OperandType& operandType) {
2781 return validOperandType(static_cast<V1_0::OperandType>(operandType));
2782 }
2783
compliantWithV1_0(const V1_3::OperandType & operandType)2784 static bool compliantWithV1_0(const V1_3::OperandType& operandType) {
2785 return validOperandType(static_cast<V1_0::OperandType>(operandType));
2786 }
2787
compliantWithV1_2(const V1_3::OperandType & operandType)2788 static bool compliantWithV1_2(const V1_3::OperandType& operandType) {
2789 return validOperandType(static_cast<V1_2::OperandType>(operandType));
2790 }
2791
convertToV1_0(const V1_2::OperandType & operandType)2792 V1_0::OperandType convertToV1_0(const V1_2::OperandType& operandType) {
2793 if (!compliantWithV1_0(operandType)) {
2794 LOG(ERROR) << "Upcasting non-compliant operand type " << toString(operandType)
2795 << " from V1_2::OperandType to V1_0::OperandType";
2796 }
2797 return static_cast<V1_0::OperandType>(operandType);
2798 }
2799
convertToV1_2(const V1_0::OperandType & operandType)2800 V1_2::OperandType convertToV1_2(const V1_0::OperandType& operandType) {
2801 return static_cast<V1_2::OperandType>(operandType);
2802 }
2803
convertToV1_2(const V1_3::OperandType & operandType)2804 V1_2::OperandType convertToV1_2(const V1_3::OperandType& operandType) {
2805 if (!compliantWithV1_2(operandType)) {
2806 LOG(ERROR) << "Upcasting non-compliant operand type " << toString(operandType)
2807 << " from V1_3::OperandType to V1_2::OperandType";
2808 }
2809 return static_cast<V1_2::OperandType>(operandType);
2810 }
2811
convertToV1_0(const V1_3::OperandType & operandType)2812 V1_0::OperandType convertToV1_0(const V1_3::OperandType& operandType) {
2813 if (!compliantWithV1_0(operandType)) {
2814 LOG(ERROR) << "Upcasting non-compliant operand type " << toString(operandType)
2815 << " from V1_3::Operand to V1_0::Operand";
2816 }
2817 return static_cast<V1_0::OperandType>(operandType);
2818 }
2819
compliantWithV1_0(hal::V1_0::OperandLifeTime lifetime)2820 bool compliantWithV1_0(hal::V1_0::OperandLifeTime lifetime) {
2821 return true;
2822 }
2823
compliantWithV1_0(hal::V1_3::OperandLifeTime lifetime)2824 bool compliantWithV1_0(hal::V1_3::OperandLifeTime lifetime) {
2825 return lifetime != V1_3::OperandLifeTime::SUBGRAPH;
2826 }
2827
compliantWithV1_3(hal::V1_0::OperandLifeTime lifetime)2828 bool compliantWithV1_3(hal::V1_0::OperandLifeTime lifetime) {
2829 return true;
2830 }
2831
compliantWithV1_3(hal::V1_3::OperandLifeTime lifetime)2832 bool compliantWithV1_3(hal::V1_3::OperandLifeTime lifetime) {
2833 return true;
2834 }
2835
convertToV1_0(V1_0::OperandLifeTime lifetime)2836 V1_0::OperandLifeTime convertToV1_0(V1_0::OperandLifeTime lifetime) {
2837 return lifetime;
2838 }
2839
convertToV1_0(V1_3::OperandLifeTime lifetime)2840 V1_0::OperandLifeTime convertToV1_0(V1_3::OperandLifeTime lifetime) {
2841 if (!compliantWithV1_0(lifetime)) {
2842 LOG(ERROR) << "Upcasting non-compliant lifetime " << toString(lifetime)
2843 << " from V1_3 to V1_0";
2844 }
2845 return static_cast<V1_0::OperandLifeTime>(lifetime);
2846 }
2847
convertToV1_3(V1_0::OperandLifeTime lifetime)2848 V1_3::OperandLifeTime convertToV1_3(V1_0::OperandLifeTime lifetime) {
2849 return static_cast<V1_3::OperandLifeTime>(lifetime);
2850 }
2851
convertToV1_3(V1_3::OperandLifeTime lifetime)2852 V1_3::OperandLifeTime convertToV1_3(V1_3::OperandLifeTime lifetime) {
2853 return lifetime;
2854 }
2855
convertToV1_0(const V1_2::Operand & operand)2856 V1_0::Operand convertToV1_0(const V1_2::Operand& operand) {
2857 return {.type = convertToV1_0(operand.type),
2858 .dimensions = operand.dimensions,
2859 .numberOfConsumers = operand.numberOfConsumers,
2860 .scale = operand.scale,
2861 .zeroPoint = operand.zeroPoint,
2862 .lifetime = convertToV1_0(operand.lifetime),
2863 .location = operand.location};
2864 }
2865
convertToV1_0(const V1_3::Operand & operand)2866 V1_0::Operand convertToV1_0(const V1_3::Operand& operand) {
2867 return {.type = convertToV1_0(operand.type),
2868 .dimensions = operand.dimensions,
2869 .numberOfConsumers = operand.numberOfConsumers,
2870 .scale = operand.scale,
2871 .zeroPoint = operand.zeroPoint,
2872 .lifetime = convertToV1_0(operand.lifetime),
2873 .location = operand.location};
2874 }
2875
convertToV1_2(const V1_0::Operand & operand)2876 V1_2::Operand convertToV1_2(const V1_0::Operand& operand) {
2877 return {.type = convertToV1_2(operand.type),
2878 .dimensions = operand.dimensions,
2879 .numberOfConsumers = operand.numberOfConsumers,
2880 .scale = operand.scale,
2881 .zeroPoint = operand.zeroPoint,
2882 .lifetime = operand.lifetime,
2883 .location = operand.location};
2884 }
2885
convertToV1_2(const V1_3::Operand & operand)2886 V1_2::Operand convertToV1_2(const V1_3::Operand& operand) {
2887 return {.type = convertToV1_2(operand.type),
2888 .dimensions = operand.dimensions,
2889 .numberOfConsumers = operand.numberOfConsumers,
2890 .scale = operand.scale,
2891 .zeroPoint = operand.zeroPoint,
2892 .lifetime = static_cast<V1_0::OperandLifeTime>(operand.lifetime),
2893 .location = operand.location,
2894 .extraParams = operand.extraParams};
2895 }
2896
convertToV1_3(const V1_0::Operand & operand)2897 V1_3::Operand convertToV1_3(const V1_0::Operand& operand) {
2898 return {.type = static_cast<V1_3::OperandType>(operand.type),
2899 .dimensions = operand.dimensions,
2900 .numberOfConsumers = operand.numberOfConsumers,
2901 .scale = operand.scale,
2902 .zeroPoint = operand.zeroPoint,
2903 .lifetime = convertToV1_3(operand.lifetime),
2904 .location = operand.location};
2905 }
2906
convertToV1_3(const V1_2::Operand & operand)2907 V1_3::Operand convertToV1_3(const V1_2::Operand& operand) {
2908 return {.type = static_cast<V1_3::OperandType>(operand.type),
2909 .dimensions = operand.dimensions,
2910 .numberOfConsumers = operand.numberOfConsumers,
2911 .scale = operand.scale,
2912 .zeroPoint = operand.zeroPoint,
2913 .lifetime = convertToV1_3(operand.lifetime),
2914 .location = operand.location,
2915 .extraParams = operand.extraParams};
2916 }
2917
convertToV1_3(const V1_3::Operand & operand)2918 V1_3::Operand convertToV1_3(const V1_3::Operand& operand) {
2919 return operand;
2920 }
2921
convertToV1_0(const hidl_vec<V1_0::Operand> & operands)2922 hidl_vec<V1_0::Operand> convertToV1_0(const hidl_vec<V1_0::Operand>& operands) {
2923 return operands;
2924 }
2925
// Converts each V1_2 operand to V1_0 via the per-operand overload, which logs non-compliant
// operand types.
hidl_vec<V1_0::Operand> convertToV1_0(const hidl_vec<V1_2::Operand>& operands) {
    hidl_vec<V1_0::Operand> converted(operands.size());
    for (size_t i = 0; i < operands.size(); ++i) {
        converted[i] = convertToV1_0(operands[i]);
    }
    return converted;
}
2932
// Converts each V1_3 operand to V1_0 via the per-operand overload, which logs non-compliant
// operand types and lifetimes.
hidl_vec<V1_0::Operand> convertToV1_0(const hidl_vec<V1_3::Operand>& operands) {
    hidl_vec<V1_0::Operand> converted(operands.size());
    auto dst = converted.begin();
    for (const V1_3::Operand& operand : operands) {
        *dst = convertToV1_0(operand);
        ++dst;
    }
    return converted;
}
2939
// Widens each V1_0 operand to V1_2 via the per-operand overload.
hidl_vec<V1_2::Operand> convertToV1_2(const hidl_vec<V1_0::Operand>& operands) {
    const size_t count = operands.size();
    hidl_vec<V1_2::Operand> widened(count);
    for (size_t idx = 0; idx < count; ++idx) {
        widened[idx] = convertToV1_2(operands[idx]);
    }
    return widened;
}
2946
convertToV1_2(const hidl_vec<V1_2::Operand> & operands)2947 hidl_vec<V1_2::Operand> convertToV1_2(const hidl_vec<V1_2::Operand>& operands) {
2948 return operands;
2949 }
2950
// Converts each V1_3 operand to V1_2 via the per-operand overload.
hidl_vec<V1_2::Operand> convertToV1_2(const hidl_vec<V1_3::Operand>& operands) {
    hidl_vec<V1_2::Operand> converted(operands.size());
    auto dst = converted.begin();
    for (const V1_3::Operand& operand : operands) {
        *dst = convertToV1_2(operand);
        ++dst;
    }
    return converted;
}
2957
// Widens each V1_0 operand to V1_3 via the per-operand overload.
hidl_vec<V1_3::Operand> convertToV1_3(const hidl_vec<V1_0::Operand>& operands) {
    hidl_vec<V1_3::Operand> widened(operands.size());
    for (size_t i = 0; i < operands.size(); ++i) {
        widened[i] = convertToV1_3(operands[i]);
    }
    return widened;
}
2964
// Widens each V1_2 operand to V1_3 via the per-operand overload.
hidl_vec<V1_3::Operand> convertToV1_3(const hidl_vec<V1_2::Operand>& operands) {
    hidl_vec<V1_3::Operand> widened(operands.size());
    auto dst = widened.begin();
    for (const V1_2::Operand& operand : operands) {
        *dst = convertToV1_3(operand);
        ++dst;
    }
    return widened;
}
2971
convertToV1_3(const hidl_vec<V1_3::Operand> & operands)2972 hidl_vec<V1_3::Operand> convertToV1_3(const hidl_vec<V1_3::Operand>& operands) {
2973 return operands;
2974 }
2975
convertToV1_0(const V1_0::Model & model)2976 V1_0::Model convertToV1_0(const V1_0::Model& model) {
2977 return model;
2978 }
2979
convertToV1_0(const V1_1::Model & model)2980 V1_0::Model convertToV1_0(const V1_1::Model& model) {
2981 if (!compliantWithV1_0(model)) {
2982 LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
2983 << " from V1_1::Model to V1_0::Model";
2984 }
2985 return {.operands = model.operands,
2986 .operations = uncheckedConvertToV1_0(model.operations),
2987 .inputIndexes = model.inputIndexes,
2988 .outputIndexes = model.outputIndexes,
2989 .operandValues = model.operandValues,
2990 .pools = model.pools};
2991 }
2992
convertToV1_0(const V1_2::Model & model)2993 V1_0::Model convertToV1_0(const V1_2::Model& model) {
2994 if (!compliantWithV1_0(model)) {
2995 LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
2996 << " from V1_2::Model to V1_0::Model";
2997 }
2998 return {.operands = convertToV1_0(model.operands),
2999 .operations = uncheckedConvertToV1_0(model.operations),
3000 .inputIndexes = model.inputIndexes,
3001 .outputIndexes = model.outputIndexes,
3002 .operandValues = model.operandValues,
3003 .pools = model.pools};
3004 }
3005
convertToV1_0(const V1_3::Model & model)3006 V1_0::Model convertToV1_0(const V1_3::Model& model) {
3007 if (!compliantWithV1_0(model)) {
3008 LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
3009 << " from V1_3::Model to V1_0::Model";
3010 }
3011 return {.operands = convertToV1_0(model.main.operands),
3012 .operations = uncheckedConvertToV1_0(model.main.operations),
3013 .inputIndexes = model.main.inputIndexes,
3014 .outputIndexes = model.main.outputIndexes,
3015 .operandValues = model.operandValues,
3016 .pools = model.pools};
3017 }
3018
convertToV1_1(const V1_0::Model & model)3019 V1_1::Model convertToV1_1(const V1_0::Model& model) {
3020 return {.operands = model.operands,
3021 .operations = convertToV1_1(model.operations),
3022 .inputIndexes = model.inputIndexes,
3023 .outputIndexes = model.outputIndexes,
3024 .operandValues = model.operandValues,
3025 .pools = model.pools,
3026 .relaxComputationFloat32toFloat16 = false};
3027 }
3028
convertToV1_1(const V1_1::Model & model)3029 V1_1::Model convertToV1_1(const V1_1::Model& model) {
3030 return model;
3031 }
3032
convertToV1_1(const V1_2::Model & model)3033 V1_1::Model convertToV1_1(const V1_2::Model& model) {
3034 if (!compliantWithV1_1(model)) {
3035 LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
3036 << " from V1_2::Model to V1_1::Model";
3037 }
3038 return {.operands = convertToV1_0(model.operands), // Operands in 1.1 and 1.0 are identical.
3039 .operations = uncheckedConvertToV1_1(model.operations),
3040 .inputIndexes = model.inputIndexes,
3041 .outputIndexes = model.outputIndexes,
3042 .operandValues = model.operandValues,
3043 .pools = model.pools,
3044 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16};
3045 }
3046
convertToV1_1(const V1_3::Model & model)3047 V1_1::Model convertToV1_1(const V1_3::Model& model) {
3048 if (!compliantWithV1_1(model)) {
3049 LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
3050 << " from V1_3::Model to V1_1::Model";
3051 }
3052 return {// Operands in 1.1 and 1.0 are identical.
3053 .operands = convertToV1_0(model.main.operands),
3054 .operations = uncheckedConvertToV1_1(model.main.operations),
3055 .inputIndexes = model.main.inputIndexes,
3056 .outputIndexes = model.main.outputIndexes,
3057 .operandValues = model.operandValues,
3058 .pools = model.pools,
3059 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16};
3060 }
3061
convertToV1_2(const V1_0::Model & model)3062 V1_2::Model convertToV1_2(const V1_0::Model& model) {
3063 return {.operands = convertToV1_2(model.operands),
3064 .operations = convertToV1_2(model.operations),
3065 .inputIndexes = model.inputIndexes,
3066 .outputIndexes = model.outputIndexes,
3067 .operandValues = model.operandValues,
3068 .pools = model.pools,
3069 .relaxComputationFloat32toFloat16 = false};
3070 }
3071
convertToV1_2(const V1_1::Model & model)3072 V1_2::Model convertToV1_2(const V1_1::Model& model) {
3073 return {.operands = convertToV1_2(model.operands),
3074 .operations = convertToV1_2(model.operations),
3075 .inputIndexes = model.inputIndexes,
3076 .outputIndexes = model.outputIndexes,
3077 .operandValues = model.operandValues,
3078 .pools = model.pools,
3079 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16};
3080 }
3081
convertToV1_2(const V1_2::Model & model)3082 V1_2::Model convertToV1_2(const V1_2::Model& model) {
3083 return model;
3084 }
3085
convertToV1_2(const V1_3::Model & model)3086 V1_2::Model convertToV1_2(const V1_3::Model& model) {
3087 if (!compliantWithV1_2(model)) {
3088 LOG(ERROR) << "Upcasting non-compliant model " << SHOW_IF_DEBUG(toString(model))
3089 << " from V1_3::Model to V1_2::Model";
3090 }
3091 return {.operands = convertToV1_2(model.main.operands),
3092 .operations = uncheckedConvertToV1_2(model.main.operations),
3093 .inputIndexes = model.main.inputIndexes,
3094 .outputIndexes = model.main.outputIndexes,
3095 .operandValues = model.operandValues,
3096 .pools = model.pools,
3097 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
3098 .extensionNameToPrefix = model.extensionNameToPrefix};
3099 }
3100
convertToV1_3(const V1_0::Model & model)3101 V1_3::Model convertToV1_3(const V1_0::Model& model) {
3102 return {.main = {.operands = convertToV1_3(model.operands),
3103 .operations = convertToV1_3(model.operations),
3104 .inputIndexes = model.inputIndexes,
3105 .outputIndexes = model.outputIndexes},
3106 .operandValues = model.operandValues,
3107 .pools = model.pools,
3108 .relaxComputationFloat32toFloat16 = false};
3109 }
3110
convertToV1_3(const V1_1::Model & model)3111 V1_3::Model convertToV1_3(const V1_1::Model& model) {
3112 return {.main = {.operands = convertToV1_3(model.operands),
3113 .operations = convertToV1_3(model.operations),
3114 .inputIndexes = model.inputIndexes,
3115 .outputIndexes = model.outputIndexes},
3116 .operandValues = model.operandValues,
3117 .pools = model.pools,
3118 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16};
3119 }
3120
convertToV1_3(const V1_2::Model & model)3121 V1_3::Model convertToV1_3(const V1_2::Model& model) {
3122 return {.main = {.operands = convertToV1_3(model.operands),
3123 .operations = convertToV1_3(model.operations),
3124 .inputIndexes = model.inputIndexes,
3125 .outputIndexes = model.outputIndexes},
3126 .operandValues = model.operandValues,
3127 .pools = model.pools,
3128 .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
3129 .extensionNameToPrefix = model.extensionNameToPrefix};
3130 }
3131
convertToV1_3(const V1_3::Model & model)3132 V1_3::Model convertToV1_3(const V1_3::Model& model) {
3133 return model;
3134 }
3135
compliantWithV1_0(const V1_0::Request & request)3136 bool compliantWithV1_0(const V1_0::Request& request) {
3137 return true;
3138 }
3139
compliantWithV1_0(const V1_3::Request & request)3140 bool compliantWithV1_0(const V1_3::Request& request) {
3141 return std::all_of(request.pools.begin(), request.pools.end(), [](const auto& pool) {
3142 if (pool.getDiscriminator() != V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) {
3143 return false;
3144 }
3145 const auto& name = pool.hidlMemory().name();
3146 return name == "ashmem" || name == "mmap_fd";
3147 });
3148 }
3149
compliantWithV1_2(const V1_3::Request & request)3150 bool compliantWithV1_2(const V1_3::Request& request) {
3151 return std::all_of(request.pools.begin(), request.pools.end(), [](const auto& pool) {
3152 if (pool.getDiscriminator() != V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory) {
3153 return false;
3154 }
3155 const auto& name = pool.hidlMemory().name();
3156 return name == "ashmem" || name == "mmap_fd" || name == "hardware_buffer_blob" ||
3157 name == "hardware_buffer";
3158 });
3159 }
3160
convertToV1_0(const V1_3::Request::MemoryPool & pool)3161 static hidl_memory convertToV1_0(const V1_3::Request::MemoryPool& pool) {
3162 switch (pool.getDiscriminator()) {
3163 case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory:
3164 return pool.hidlMemory();
3165 case V1_3::Request::MemoryPool::hidl_discriminator::token:
3166 return hidl_memory{};
3167 }
3168 }
3169
convertToV1_3(const hidl_memory & pool)3170 static V1_3::Request::MemoryPool convertToV1_3(const hidl_memory& pool) {
3171 V1_3::Request::MemoryPool ret;
3172 ret.hidlMemory(pool);
3173 return ret;
3174 }
3175
convertToV1_0(const V1_0::Request & request)3176 V1_0::Request convertToV1_0(const V1_0::Request& request) {
3177 return request;
3178 }
3179
uncheckedConvertToV1_0(const V1_3::Request & request)3180 static V1_0::Request uncheckedConvertToV1_0(const V1_3::Request& request) {
3181 hidl_vec<hidl_memory> pools(request.pools.size());
3182 std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
3183 [](const auto& pool) { return convertToV1_0(pool); });
3184 return {.inputs = request.inputs, .outputs = request.outputs, .pools = std::move(pools)};
3185 }
3186
convertToV1_0(const V1_3::Request & request)3187 V1_0::Request convertToV1_0(const V1_3::Request& request) {
3188 if (!compliantWithV1_0(request)) {
3189 LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
3190 << " from V1_3::Request to V1_0::Request of version 1.0";
3191 }
3192 return uncheckedConvertToV1_0(request);
3193 }
3194
convertToV1_2(const V1_3::Request & request)3195 V1_0::Request convertToV1_2(const V1_3::Request& request) {
3196 if (!compliantWithV1_2(request)) {
3197 LOG(ERROR) << "Upcasting non-compliant request " << SHOW_IF_DEBUG(toString(request))
3198 << " from V1_3::Request to V1_0::Request of version 1.2";
3199 }
3200 return uncheckedConvertToV1_0(request);
3201 }
3202
convertToV1_3(const V1_0::Request & request)3203 V1_3::Request convertToV1_3(const V1_0::Request& request) {
3204 hidl_vec<V1_3::Request::MemoryPool> pools(request.pools.size());
3205 std::transform(request.pools.begin(), request.pools.end(), pools.begin(),
3206 [](const auto& pool) { return convertToV1_3(pool); });
3207 return {.inputs = request.inputs, .outputs = request.outputs, .pools = std::move(pools)};
3208 }
3209
convertToV1_3(const V1_3::Request & request)3210 V1_3::Request convertToV1_3(const V1_3::Request& request) {
3211 return request;
3212 }
3213
syncWait(int fd,int timeout)3214 FenceState syncWait(int fd, int timeout) {
3215 // This implementation is directly based on the ::sync_wait() implementation.
3216
3217 struct pollfd fds;
3218 int ret;
3219
3220 if (fd < 0) {
3221 errno = EINVAL;
3222 return FenceState::UNKNOWN;
3223 }
3224
3225 fds.fd = fd;
3226 fds.events = POLLIN;
3227
3228 do {
3229 ret = poll(&fds, 1, timeout);
3230 if (ret > 0) {
3231 if (fds.revents & POLLNVAL) {
3232 errno = EINVAL;
3233 return FenceState::UNKNOWN;
3234 }
3235 if (fds.revents & POLLERR) {
3236 errno = EINVAL;
3237 return FenceState::ERROR;
3238 }
3239 return FenceState::SIGNALED;
3240 } else if (ret == 0) {
3241 errno = ETIME;
3242 return FenceState::ACTIVE;
3243 }
3244 } while (ret == -1 && (errno == EINTR || errno == EAGAIN));
3245
3246 return FenceState::UNKNOWN;
3247 }
3248
3249 #ifdef NN_DEBUGGABLE
getProp(const char * str,uint32_t defaultValue)3250 uint32_t getProp(const char* str, uint32_t defaultValue) {
3251 const std::string propStr = android::base::GetProperty(str, "");
3252 if (propStr.size() > 0) {
3253 return std::stoi(propStr);
3254 } else {
3255 return defaultValue;
3256 }
3257 }
3258 #endif // NN_DEBUGGABLE
3259
3260 } // namespace nn
3261 } // namespace android
3262