1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ValidateHal"
18 
19 #include "ValidateHal.h"
20 
#include <android-base/logging.h>

#include <algorithm>
#include <cstdint>
#include <set>
#include <utility>
#include <vector>

#include "NeuralNetworks.h"
#include "OperationsUtils.h"
#include "Tracing.h"
#include "Utils.h"
32 
33 namespace android {
34 namespace nn {
35 
36 using namespace hal;
37 
38 template <class T_Model>
39 struct ModelToHalVersion;
40 template <>
41 struct ModelToHalVersion<V1_0::Model> {
42     static constexpr HalVersion version = HalVersion::V1_0;
43 };
44 template <>
45 struct ModelToHalVersion<V1_1::Model> {
46     static constexpr HalVersion version = HalVersion::V1_1;
47 };
48 template <>
49 struct ModelToHalVersion<V1_2::Model> {
50     static constexpr HalVersion version = HalVersion::V1_2;
51 };
52 template <>
53 struct ModelToHalVersion<V1_3::Model> {
54     static constexpr HalVersion version = HalVersion::V1_3;
55 };
56 
57 class MemoryAccessVerifier {
58    public:
MemoryAccessVerifier(const hidl_vec<hidl_memory> & pools)59     MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
60         : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
61         for (size_t i = 0; i < mPoolCount; i++) {
62             mPoolSizes[i] = pools[i].size();
63         }
64     }
MemoryAccessVerifier(const hidl_vec<V1_3::Request::MemoryPool> & pools)65     MemoryAccessVerifier(const hidl_vec<V1_3::Request::MemoryPool>& pools)
66         : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
67         for (size_t i = 0; i < mPoolCount; i++) {
68             switch (pools[i].getDiscriminator()) {
69                 case Request::MemoryPool::hidl_discriminator::hidlMemory:
70                     mPoolSizes[i] = pools[i].hidlMemory().size();
71                     break;
72                 case Request::MemoryPool::hidl_discriminator::token:
73                     // Set size to 0 to enforce length == 0 && offset == 0.
74                     mPoolSizes[i] = 0;
75                     break;
76             }
77         }
78     }
validate(const DataLocation & location) const79     bool validate(const DataLocation& location) const {
80         if (location.poolIndex >= mPoolCount) {
81             LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
82             return false;
83         }
84         const size_t size = mPoolSizes[location.poolIndex];
85         // Do the addition using size_t to avoid potential wrap-around problems.
86         if (static_cast<size_t>(location.offset) + location.length > size) {
87             LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
88                        << location.offset << " and length " << location.length
89                        << " exceeds pool size of " << size;
90             return false;
91         }
92         return true;
93     }
94 
95    private:
96     size_t mPoolCount;
97     std::vector<size_t> mPoolSizes;
98 };
99 
validateOperandExtraParams(const V1_3::Operand & operand,uint32_t index)100 static bool validateOperandExtraParams(const V1_3::Operand& operand, uint32_t index) {
101     switch (operand.type) {
102         case OperandType::FLOAT32:
103         case OperandType::INT32:
104         case OperandType::UINT32:
105         case OperandType::BOOL:
106         case OperandType::SUBGRAPH:
107         case OperandType::TENSOR_FLOAT32:
108         case OperandType::TENSOR_FLOAT16:
109         case OperandType::TENSOR_INT32:
110         case OperandType::TENSOR_QUANT8_ASYMM:
111         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
112         case OperandType::TENSOR_QUANT8_SYMM:
113         case OperandType::TENSOR_QUANT16_ASYMM:
114         case OperandType::TENSOR_QUANT16_SYMM:
115         case OperandType::TENSOR_BOOL8: {
116             NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
117                          OperandExtraParams::hidl_discriminator::none)
118                     << "Operand " << index << ": Operand of type "
119                     << getOperandTypeName(operand.type)
120                     << " has incorrect extraParams: " << toString(operand.extraParams);
121         } break;
122         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
123             NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
124                          OperandExtraParams::hidl_discriminator::channelQuant)
125                     << "Operand " << index << ": Operand of type "
126                     << getOperandTypeName(operand.type) << " without a Channel Quantization params";
127             auto& channelQuant = operand.extraParams.channelQuant();
128 
129             size_t count = operand.dimensions.size();
130             NN_RET_CHECK_LT(channelQuant.channelDim, count)
131                     << "Operand " << index << ": Operand of type "
132                     << getOperandTypeName(operand.type)
133                     << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
134                     << ", must be valid dimension index in range [0, " << count << ")";
135             uint32_t expected = operand.dimensions[channelQuant.channelDim];
136             NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
137                     << "Operand " << index << ": Operand of type "
138                     << getOperandTypeName(operand.type) << " with a wrong-sized scales, "
139                     << "expected " << expected << " was " << channelQuant.scales.size();
140             NN_RET_CHECK_NE(expected, 0)
141                     << "Operand " << index << ": Operand of type "
142                     << getOperandTypeName(operand.type) << " channel dimension "
143                     << channelQuant.channelDim << " is underspecified (can't be 0)";
144             for (uint32_t i = 0; i < expected; ++i) {
145                 NN_RET_CHECK_GT(channelQuant.scales[i], .0f)
146                         << "Operand " << index << ": Operand of type "
147                         << getOperandTypeName(operand.type) << " with a negative value in scales["
148                         << i << "]=" << channelQuant.scales[i];
149             }
150         } break;
151         default: {
152             if (isExtensionOperandType(operand.type)) {
153                 NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
154                                      OperandExtraParams::hidl_discriminator::extension ||
155                              operand.extraParams.getDiscriminator() ==
156                                      OperandExtraParams::hidl_discriminator::none)
157                         << "Operand " << index << ": Extension operand of type "
158                         << getOperandTypeName(operand.type)
159                         << " has incorrect extraParams: " << toString(operand.extraParams);
160             }
161             // No validation for OEM types.
162         } break;
163     }
164     return true;
165 }
166 
167 template <typename VersionedOperand>
validateOperands(const hidl_vec<VersionedOperand> & operands,const hidl_vec<uint8_t> & operandValues,const hidl_vec<hidl_memory> & pools,const hidl_vec<Subgraph> & subgraphs,bool allowUnspecifiedRank)168 static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
169                              const hidl_vec<uint8_t>& operandValues,
170                              const hidl_vec<hidl_memory>& pools,
171                              const hidl_vec<Subgraph>& subgraphs, bool allowUnspecifiedRank) {
172     uint32_t index = 0;
173     MemoryAccessVerifier poolVerifier(pools);
174     for (auto& versionedOperand : operands) {
175         if (!validOperandType(versionedOperand.type)) {
176             LOG(ERROR) << "Operand is not supported by this version: "
177                        << toString(versionedOperand.type);
178             return false;
179         }
180         // Once we are sure the operand is supported by its version, it is safe
181         // to convert it to the latest version for the rest of the validations.
182         V1_3::Operand operand = convertToV1_3(versionedOperand);
183         // Validate type and dimensions.
184         switch (operand.type) {
185             case OperandType::FLOAT16:
186             case OperandType::FLOAT32:
187             case OperandType::INT32:
188             case OperandType::UINT32:
189             case OperandType::BOOL:
190             case OperandType::SUBGRAPH:
191             case OperandType::OEM: {
192                 size_t count = operand.dimensions.size();
193                 if (count != 0) {
194                     LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
195                                << count;
196                     return false;
197                 }
198                 break;
199             }
200             case OperandType::TENSOR_FLOAT16:
201             case OperandType::TENSOR_FLOAT32:
202             case OperandType::TENSOR_INT32:
203             case OperandType::TENSOR_QUANT8_ASYMM:
204             case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
205             case OperandType::TENSOR_QUANT8_SYMM:
206             case OperandType::TENSOR_QUANT16_ASYMM:
207             case OperandType::TENSOR_QUANT16_SYMM:
208             case OperandType::TENSOR_BOOL8:
209             case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
210             case OperandType::TENSOR_OEM_BYTE: {
211                 if ((!allowUnspecifiedRank || operand.lifetime == OperandLifeTime::CONSTANT_COPY ||
212                      operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE) &&
213                     operand.dimensions.size() == 0) {
214                     LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
215                     return false;
216                 }
217                 break;
218             }
219             default: {
220                 if (!isExtensionOperandType(operand.type)) {
221                     LOG(ERROR) << "Operand " << index << ": Invalid operand type "
222                                << toString(operand.type);
223                     return false;
224                 }
225             } break;
226         }
227 
228         // Validate the scale.
229         switch (operand.type) {
230             case OperandType::FLOAT16:
231             case OperandType::FLOAT32:
232             case OperandType::INT32:
233             case OperandType::UINT32:
234             case OperandType::BOOL:
235             case OperandType::SUBGRAPH:
236             case OperandType::TENSOR_FLOAT16:
237             case OperandType::TENSOR_FLOAT32:
238             case OperandType::TENSOR_BOOL8:
239             case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
240                 if (operand.scale != 0.f) {
241                     LOG(ERROR) << "Operand " << index << ": Operand of type "
242                                << getOperandTypeName(operand.type) << " with a non-zero scale ("
243                                << operand.scale << ")";
244                     return false;
245                 }
246                 break;
247             case OperandType::TENSOR_INT32:
248                 // TENSOR_INT32 may be used with or without scale, depending on the operation.
249                 if (operand.scale < 0.f) {
250                     LOG(ERROR) << "Operand " << index << ": Operand of type "
251                                << getOperandTypeName(operand.type) << " with a negative scale";
252                     return false;
253                 }
254                 break;
255             case OperandType::TENSOR_QUANT8_ASYMM:
256             case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
257             case OperandType::TENSOR_QUANT8_SYMM:
258             case OperandType::TENSOR_QUANT16_ASYMM:
259             case OperandType::TENSOR_QUANT16_SYMM:
260                 if (operand.scale <= 0.f) {
261                     LOG(ERROR) << "Operand " << index << ": Operand of type "
262                                << getOperandTypeName(operand.type) << " with a non-positive scale";
263                     return false;
264                 }
265                 break;
266             default:
267                 if (isExtensionOperandType(operand.type) && operand.scale != 0.f) {
268                     LOG(ERROR) << "Operand " << index << ": Operand of type "
269                                << getOperandTypeName(operand.type) << " with a non-zero scale ("
270                                << operand.scale << ")";
271                     return false;
272                 }
273                 // No validation for OEM types.
274                 // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
275                 break;
276         }
277 
278         // Validate the zeroPoint.
279         switch (operand.type) {
280             case OperandType::FLOAT16:
281             case OperandType::FLOAT32:
282             case OperandType::INT32:
283             case OperandType::UINT32:
284             case OperandType::BOOL:
285             case OperandType::SUBGRAPH:
286             case OperandType::TENSOR_FLOAT16:
287             case OperandType::TENSOR_FLOAT32:
288             case OperandType::TENSOR_INT32:
289             case OperandType::TENSOR_BOOL8:
290             case OperandType::TENSOR_QUANT8_SYMM:
291             case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
292                 if (operand.zeroPoint != 0) {
293                     LOG(ERROR) << "Operand " << index << ": Operand of type "
294                                << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
295                                << operand.zeroPoint;
296                     return false;
297                 }
298                 break;
299             case OperandType::TENSOR_QUANT8_ASYMM:
300                 if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
301                     LOG(ERROR) << "Operand " << index << ": Operand of type "
302                                << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
303                                << operand.zeroPoint << ", must be in range [0, 255]";
304                     return false;
305                 }
306                 break;
307             case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
308                 if (operand.zeroPoint < -128 || operand.zeroPoint > 127) {
309                     LOG(ERROR) << "Operand " << index << ": Operand of type "
310                                << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
311                                << operand.zeroPoint << ", must be in range [-128, 127]";
312                     return false;
313                 }
314                 break;
315             case OperandType::TENSOR_QUANT16_ASYMM:
316                 if (operand.zeroPoint < 0 || operand.zeroPoint > 65535) {
317                     LOG(ERROR) << "Operand " << index << ": Operand of type "
318                                << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
319                                << operand.zeroPoint << ", must be in range [0, 65535]";
320                     return false;
321                 }
322                 break;
323             case OperandType::TENSOR_QUANT16_SYMM:
324                 if (operand.zeroPoint != 0) {
325                     LOG(ERROR) << "Operand " << index << ": Operand of type "
326                                << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
327                                << operand.zeroPoint;
328                     return false;
329                 }
330                 break;
331             default:
332                 if (isExtensionOperandType(operand.type) && operand.zeroPoint != 0) {
333                     LOG(ERROR) << "Operand " << index << ": Operand of type "
334                                << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
335                                << operand.zeroPoint;
336                     return false;
337                 }
338                 // No validation for OEM types.
339                 break;
340         }
341 
342         NN_RET_CHECK(validateOperandExtraParams(operand, index));
343 
344         // Validate the lifetime and the location.
345         const DataLocation& location = operand.location;
346         switch (operand.lifetime) {
347             case OperandLifeTime::CONSTANT_COPY:
348                 if (location.poolIndex != 0) {
349                     LOG(ERROR) << "Operand " << index
350                                << ": CONSTANT_COPY with a non-zero poolIndex "
351                                << location.poolIndex;
352                     return false;
353                 }
354                 // Do the addition using size_t to avoid potential wrap-around problems.
355                 if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) {
356                     LOG(ERROR) << "Operand " << index
357                                << ": OperandValue location out of range.  Starts at "
358                                << location.offset << ", length " << location.length << ", max "
359                                << operandValues.size();
360                     return false;
361                 }
362                 break;
363             case OperandLifeTime::CONSTANT_REFERENCE:
364                 if (!poolVerifier.validate(location)) {
365                     return false;
366                 }
367                 break;
368             case OperandLifeTime::TEMPORARY_VARIABLE:
369             case OperandLifeTime::SUBGRAPH_INPUT:
370             case OperandLifeTime::SUBGRAPH_OUTPUT:
371             case OperandLifeTime::NO_VALUE:
372                 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
373                     LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
374                                << location.poolIndex << ", offset " << location.offset
375                                << ", or length " << location.length << " for operand of lifetime "
376                                << toString(operand.lifetime);
377                     return false;
378                 }
379                 break;
380             case OperandLifeTime::SUBGRAPH: {
381                 if (location.poolIndex != 0) {
382                     LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero poolIndex "
383                                << location.poolIndex;
384                     return false;
385                 }
386                 if (location.offset >= subgraphs.size()) {
387                     LOG(ERROR) << "Subgraph index out of range: " << location.offset
388                                << " >= " << subgraphs.size();
389                     return false;
390                 }
391                 if (location.length != 0) {
392                     LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero length "
393                                << location.length;
394                     return false;
395                 }
396             } break;
397             default:
398                 LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
399                            << toString(operand.lifetime);
400                 return false;
401         }
402 
403         // Make sure SUBGRAPH operand type and lifetime always go together.
404         if ((operand.type == OperandType::SUBGRAPH) !=
405             (operand.lifetime == OperandLifeTime::SUBGRAPH)) {
406             LOG(ERROR) << "Operand " << index << ": Operand of type " << toString(operand.type)
407                        << " cannot have lifetime " << toString(operand.lifetime);
408             return false;
409         }
410 
411         // For constants, validate that the length is as expected. The other lifetimes
412         // expect the length to be 0. Don't validate for OEM types.
413         if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
414             operand.lifetime == OperandLifeTime::CONSTANT_COPY) {
415             if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM &&
416                 operand.type != OperandType::TENSOR_OEM_BYTE) {
417                 uint32_t expectedLength = nonExtensionOperandSizeOfData(operand);
418                 if (location.length != expectedLength) {
419                     LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
420                                << " expected a size of " << expectedLength << " but got "
421                                << location.length;
422                     return false;
423                 }
424             }
425         }
426 
427         index++;
428     }
429     return true;
430 }
431 
getHalVersion(const V1_0::Operation &)432 static HalVersion getHalVersion(const V1_0::Operation&) {
433     return HalVersion::V1_0;
434 }
435 
getHalVersion(const V1_1::Operation &)436 static HalVersion getHalVersion(const V1_1::Operation&) {
437     return HalVersion::V1_1;
438 }
439 
getHalVersion(const V1_2::Operation &)440 static HalVersion getHalVersion(const V1_2::Operation&) {
441     return HalVersion::V1_2;
442 }
443 
getHalVersion(const V1_3::Operation &)444 static HalVersion getHalVersion(const V1_3::Operation&) {
445     return HalVersion::V1_3;
446 }
447 
448 template <typename VersionedOperation>
validateOperations(const hidl_vec<VersionedOperation> & operations,const hidl_vec<Operand> & operands,const hidl_vec<Subgraph> & subgraphs,ValidationMode mode)449 static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
450                                const hidl_vec<Operand>& operands,
451                                const hidl_vec<Subgraph>& subgraphs, ValidationMode mode) {
452     auto isValidSubgraphReference = [&subgraphs](const Operand& modelOperand) -> bool {
453         NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
454                 << "Unexpected operand type: " << toString(modelOperand.type);
455         NN_RET_CHECK_LT(modelOperand.location.offset, subgraphs.size())
456                 << "Invalid subgraph reference";
457         return true;
458     };
459     auto getSubgraph = [&subgraphs](const Operand& modelOperand) -> const Subgraph* {
460         CHECK_LT(modelOperand.location.offset, subgraphs.size());
461         return &subgraphs[modelOperand.location.offset];
462     };
463     auto getInputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
464         return getSubgraph(modelOperand)->inputIndexes.size();
465     };
466     auto getOutputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
467         return getSubgraph(modelOperand)->outputIndexes.size();
468     };
469     auto getInputOperand = [&getSubgraph](const Operand& modelOperand,
470                                           uint32_t index) -> const Operand* {
471         const Subgraph& subgraph = *getSubgraph(modelOperand);
472         CHECK_LT(subgraph.inputIndexes[index], subgraph.operands.size());
473         return &subgraph.operands[subgraph.inputIndexes[index]];
474     };
475     auto getOutputOperand = [&getSubgraph](const Operand& modelOperand,
476                                            uint32_t index) -> const Operand* {
477         const Subgraph& subgraph = *getSubgraph(modelOperand);
478         CHECK_LT(subgraph.outputIndexes[index], subgraph.operands.size());
479         return &subgraph.operands[subgraph.outputIndexes[index]];
480     };
481     for (auto& op : operations) {
482         // TODO Validate the shapes and any known values. This is currently
483         // done in CpuExecutor but should be done here for all drivers.
484         int error = validateOperation(
485                 static_cast<int32_t>(op.type), op.inputs.size(),
486                 op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
487                 op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op),
488                 {.isValidSubgraphReference = isValidSubgraphReference,
489                  .getSubgraphInputCount = getInputCount,
490                  .getSubgraphOutputCount = getOutputCount,
491                  .getSubgraphInputOperand = getInputOperand,
492                  .getSubgraphOutputOperand = getOutputOperand,
493                  // 1.3 HAL does not support CF operations with operands of
494                  // unknown size. See http://b/132458982#comment63.
495                  .allowControlFlowOperationWithOperandOfUnknownSize =
496                          mode == ValidationMode::RUNTIME});
497         if (error != ANEURALNETWORKS_NO_ERROR) {
498             LOG(ERROR) << "Invalid operation " << toString(op.type);
499             return false;
500         }
501 
502         // This is redundant because of the checks in validateGraph(),
503         // but it is retained here in order to emit more informative
504         // error messages.
505         for (uint32_t i : op.outputs) {
506             const Operand& operand = operands[i];
507             if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE &&
508                 operand.lifetime != OperandLifeTime::SUBGRAPH_OUTPUT) {
509                 LOG(ERROR) << "Writing to operand " << i << " with incompatible lifetime "
510                            << toString(operand.lifetime);
511                 return false;
512             }
513         }
514     }
515     return true;
516 }
517 
validatePool(const hidl_memory & pool,HalVersion ver)518 bool validatePool(const hidl_memory& pool, HalVersion ver) {
519     const auto& name = pool.name();
520     if (name != "ashmem" && name != "mmap_fd" &&
521         ((ver < HalVersion::V1_2) ||
522          (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
523         LOG(ERROR) << "Unsupported memory type " << name;
524         return false;
525     }
526     if (pool.handle() == nullptr) {
527         LOG(ERROR) << "Memory of type " << name << " is null";
528         return false;
529     }
530     return true;
531 }
532 
validatePool(const V1_3::Request::MemoryPool & pool,HalVersion ver)533 bool validatePool(const V1_3::Request::MemoryPool& pool, HalVersion ver) {
534     switch (pool.getDiscriminator()) {
535         case Request::MemoryPool::hidl_discriminator::hidlMemory:
536             return validatePool(pool.hidlMemory(), ver);
537         case Request::MemoryPool::hidl_discriminator::token:
538             return pool.token() > 0;
539     }
540     LOG(FATAL) << "unknown MemoryPool discriminator";
541     return false;
542 }
543 
544 template <class T_MemoryPool>
validatePools(const hidl_vec<T_MemoryPool> & pools,HalVersion ver)545 static bool validatePools(const hidl_vec<T_MemoryPool>& pools, HalVersion ver) {
546     return std::all_of(pools.begin(), pools.end(),
547                        [ver](const auto& pool) { return validatePool(pool, ver); });
548 }
549 
validateModelInputOutputs(const hidl_vec<uint32_t> indexes,const hidl_vec<Operand> & operands,OperandLifeTime lifetime)550 static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes,
551                                       const hidl_vec<Operand>& operands, OperandLifeTime lifetime) {
552     const size_t operandCount = operands.size();
553     for (uint32_t i : indexes) {
554         if (i >= operandCount) {
555             LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
556             return false;
557         }
558         const Operand& operand = operands[i];
559         if (operand.lifetime != lifetime) {
560             LOG(ERROR) << "Model input or output operand " << i << " has lifetime of "
561                        << toString(operand.lifetime) << " instead of the expected "
562                        << toString(lifetime);
563             return false;
564         }
565     }
566 
567     std::vector<uint32_t> sortedIndexes = indexes;
568     std::sort(sortedIndexes.begin(), sortedIndexes.end());
569     auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
570     if (adjacentI != sortedIndexes.end()) {
571         LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
572         return false;
573     }
574 
575     for (size_t i = 0; i < operands.size(); ++i) {
576         if (operands[i].lifetime == lifetime &&
577             !binary_search(sortedIndexes.begin(), sortedIndexes.end(), i)) {
578             LOG(ERROR) << "Operand " << i << " marked as " << toString(lifetime)
579                        << " but is not included in Model input or output indexes";
580             return false;
581         }
582     }
583 
584     return true;
585 }
586 
587 template <typename VersionedModelOrSubgraph>
validateGraph(const VersionedModelOrSubgraph & model)588 static bool validateGraph(const VersionedModelOrSubgraph& model) {
589     // set up counts
590     std::vector<uint32_t> operandNumberOfConsumers(model.operands.size(), 0);
591     //     Either the operand has a known value before model execution
592     //     begins, or we've seen a writer for this operand while
593     //     walking operands in execution order.
594     std::vector<bool> operandValueKnown(model.operands.size(), false);
595 
596     // mark known operands
597     for (size_t i = 0; i < model.operands.size(); ++i) {
598         const auto& operand = model.operands[i];
599         const OperandLifeTime lifetime = convertToV1_3(operand.lifetime);
600         operandValueKnown[i] = lifetime == OperandLifeTime::SUBGRAPH_INPUT ||
601                                lifetime == OperandLifeTime::CONSTANT_COPY ||
602                                lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
603                                lifetime == OperandLifeTime::NO_VALUE ||
604                                lifetime == OperandLifeTime::SUBGRAPH;
605     }
606 
607     // Validate that operations are sorted into execution order.
608     //
609     // If there is a cycle in the graph, the operations will not
610     // appear to be sorted into execution order: Some operation will
611     // have an input for which operandValueKnown[] is false.
612     for (size_t i = 0; i < model.operations.size(); ++i) {
613         const auto& operation = model.operations[i];
614 
615         for (size_t j = 0; j < operation.inputs.size(); ++j) {
616             uint32_t k = operation.inputs[j];
617             if (!operandValueKnown[k]) {
618                 LOG(ERROR) << "Operation " << i << " input " << j << " (operand " << k
619                            << ") is read before it is written";
620                 return false;
621             }
622             operandNumberOfConsumers[k]++;
623         }
624 
625         for (size_t j = 0; j < operation.outputs.size(); ++j) {
626             uint32_t k = operation.outputs[j];
627             if (operandValueKnown[k]) {
628                 // Assuming validateOperations() has returned true, we
629                 // know that this output is TEMPORARY_VARIABLE or
630                 // MODEL_OUTPUT, and so the only way
631                 // operandValueKnown[k] can be true is if we've
632                 // already seen a writer for this operand.
633                 LOG(ERROR) << "Operation " << i << " output " << j << " (operand " << k
634                            << ") has already been written";
635                 return false;
636             }
637             operandValueKnown[k] = true;
638         }
639     }
640 
641     // validate number of consumers
642     //
643     // TODO Because we have to validate it, there was no point in including it
644     // in struct Operand. For the next release, consider removing unless we have
645     // an additional process in system space that creates this value. In that
646     // case, it would not have to be validated.
647     for (size_t i = 0; i < model.operands.size(); ++i) {
648         if (model.operands[i].numberOfConsumers != operandNumberOfConsumers[i]) {
649             LOG(ERROR) << "Operand " << i << " has incorrect number of consumers "
650                        << model.operands[i].numberOfConsumers << ", expected "
651                        << operandNumberOfConsumers[i];
652             return false;
653         }
654     }
655 
656     // verify all operands are written
657     for (size_t i = 0; i < model.operands.size(); ++i) {
658         if (!operandValueKnown[i]) {
659             LOG(ERROR) << "Operand " << i << " is never written";
660             return false;
661         }
662     }
663 
664     return true;
665 }
666 
667 // Makes sure the model does not contain subgraph reference cycles.
checkNoReferenceCycles(const V1_3::Model & model,const V1_3::Subgraph & subgraph,std::set<const V1_3::Subgraph * > * path)668 static bool checkNoReferenceCycles(const V1_3::Model& model, const V1_3::Subgraph& subgraph,
669                                    std::set<const V1_3::Subgraph*>* path) {
670     auto [_, isNew] = path->insert(&subgraph);
671     if (!isNew) {
672         LOG(ERROR) << "Model contains a circular subgraph reference";
673         return false;
674     }
675     for (const Operand& operand : subgraph.operands) {
676         if (operand.lifetime == OperandLifeTime::SUBGRAPH) {
677             uint32_t refSubgraphIndex = operand.location.offset;
678             if (!checkNoReferenceCycles(model, model.referenced[refSubgraphIndex], path)) {
679                 return false;
680             }
681         }
682     }
683     path->erase(&subgraph);
684     return true;
685 }
686 
checkNoReferenceCycles(const V1_3::Model & model)687 static bool checkNoReferenceCycles(const V1_3::Model& model) {
688     std::set<const V1_3::Subgraph*> path;
689     return checkNoReferenceCycles(model, model.main, &path);
690 }
691 
692 template <class T_Model>
validateModel(const T_Model & model,ValidationMode mode)693 bool validateModel(const T_Model& model, ValidationMode mode) {
694     NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
695     HalVersion version = ModelToHalVersion<T_Model>::version;
696     if (model.operations.size() == 0 || model.operands.size() == 0) {
697         LOG(ERROR) << "Invalid empty model.";
698         return false;
699     }
700     // We only need versioned operands for their validation. For all the other
701     // validations we can use operands upcasted to the latest version.
702     const hidl_vec<Operand> latestVersionOperands = convertToV1_3(model.operands);
703     return (validateOperands(model.operands, model.operandValues, model.pools, /*subgraphs=*/{},
704                              /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
705             validateOperations(model.operations, latestVersionOperands, /*subgraphs=*/{}, mode) &&
706             validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
707                                       OperandLifeTime::SUBGRAPH_INPUT) &&
708             validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
709                                       OperandLifeTime::SUBGRAPH_OUTPUT) &&
710             validatePools(model.pools, version) && validateGraph(model));
711 }
712 
713 template bool validateModel<V1_0::Model>(const V1_0::Model& model, ValidationMode mode);
714 template bool validateModel<V1_1::Model>(const V1_1::Model& model, ValidationMode mode);
715 template bool validateModel<V1_2::Model>(const V1_2::Model& model, ValidationMode mode);
716 
717 template <>
validateModel(const V1_3::Model & model,ValidationMode mode)718 bool validateModel(const V1_3::Model& model, ValidationMode mode) {
719     NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
720     if (model.main.operations.size() == 0 || model.main.operands.size() == 0) {
721         LOG(ERROR) << "Invalid empty model.";
722         return false;
723     }
724     auto validateSubgraph = [&model, mode](const Subgraph& subgraph) -> bool {
725         return (validateOperands(subgraph.operands, model.operandValues, model.pools,
726                                  model.referenced, /*allowUnspecifiedRank=*/true) &&
727                 validateOperations(subgraph.operations, subgraph.operands, model.referenced,
728                                    mode) &&
729                 validateModelInputOutputs(subgraph.inputIndexes, subgraph.operands,
730                                           OperandLifeTime::SUBGRAPH_INPUT) &&
731                 validateModelInputOutputs(subgraph.outputIndexes, subgraph.operands,
732                                           OperandLifeTime::SUBGRAPH_OUTPUT) &&
733                 validateGraph(subgraph));
734     };
735     return (validateSubgraph(model.main) &&
736             std::all_of(model.referenced.begin(), model.referenced.end(), validateSubgraph) &&
737             validatePools(model.pools, HalVersion::V1_3) && checkNoReferenceCycles(model));
738 }
739 
740 // Validates the arguments of a request. type is either "input" or "output" and is used
741 // for printing error messages. The operandIndexes is the appropriate array of input
742 // or output operand indexes that was passed to the ANeuralNetworksModel_identifyInputsAndOutputs.
validateRequestArguments(const hidl_vec<RequestArgument> & requestArguments,const hidl_vec<uint32_t> & operandIndexes,const hidl_vec<Operand> & operands,const MemoryAccessVerifier & poolVerifier,bool allowUnspecified,const char * type)743 static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments,
744                                      const hidl_vec<uint32_t>& operandIndexes,
745                                      const hidl_vec<Operand>& operands,
746                                      const MemoryAccessVerifier& poolVerifier,
747                                      bool allowUnspecified, const char* type) {
748     // The request should specify as many arguments as were described in the model.
749     const size_t requestArgumentCount = requestArguments.size();
750     if (requestArgumentCount != operandIndexes.size()) {
751         LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
752                    << "s but the model has " << operandIndexes.size();
753         return false;
754     }
755     for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
756          requestArgumentIndex++) {
757         const RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
758         const DataLocation& location = requestArgument.location;
759         // Get the operand index for this argument. We extract it from the list
760         // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
761         // We assume in this function that the model has been validated already.
762         const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
763         const Operand& operand = operands[operandIndex];
764         if (requestArgument.hasNoValue) {
765             if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
766                 requestArgument.dimensions.size() != 0) {
767                 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
768                            << " has no value yet has details.";
769                 return false;
770             }
771         } else {
772             // Validate the location.
773             if (!poolVerifier.validate(location)) {
774                 return false;
775             }
776             // If the argument specified a dimension, validate it.
777             uint32_t modelRank = operand.dimensions.size();
778             uint32_t requestRank = requestArgument.dimensions.size();
779             if (requestRank == 0) {
780                 if (!allowUnspecified) {
781                     // Validate that all the dimensions are specified in the model.
782                     for (size_t i = 0; i < modelRank; i++) {
783                         if (operand.dimensions[i] == 0) {
784                             LOG(ERROR) << "Model has dimension " << i
785                                        << " set to 0 but the request does specify the dimension.";
786                             return false;
787                         }
788                     }
789                 }
790             } else {
791                 if (modelRank != 0 && requestRank != modelRank) {
792                     LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
793                                << " has number of dimensions (" << requestRank
794                                << ") different than the model's (" << modelRank << ")";
795                     return false;
796                 }
797                 for (size_t i = 0; i < requestRank; i++) {
798                     if (modelRank != 0 && requestArgument.dimensions[i] != operand.dimensions[i] &&
799                         operand.dimensions[i] != 0) {
800                         LOG(ERROR)
801                                 << "Request " << type << " " << requestArgumentIndex
802                                 << " has dimension " << i << " of " << requestArgument.dimensions[i]
803                                 << " different than the model's " << operand.dimensions[i];
804                         return false;
805                     }
806                     if (requestArgument.dimensions[i] == 0 && !allowUnspecified) {
807                         LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
808                                    << " has dimension " << i << " of zero";
809                         return false;
810                     }
811                 }
812             }
813         }
814     }
815     return true;
816 }
817 
818 template <class T_Request, class T_Model>
validateRequest(const T_Request & request,const T_Model & model,bool allowUnspecifiedOutput)819 bool validateRequest(const T_Request& request, const T_Model& model, bool allowUnspecifiedOutput) {
820     HalVersion version = ModelToHalVersion<T_Model>::version;
821     MemoryAccessVerifier poolVerifier(request.pools);
822     return (validateRequestArguments(request.inputs, model.inputIndexes,
823                                      convertToV1_3(model.operands), poolVerifier,
824                                      /*allowUnspecified=*/false, "input") &&
825             validateRequestArguments(
826                     request.outputs, model.outputIndexes, convertToV1_3(model.operands),
827                     poolVerifier,
828                     /*allowUnspecified=*/version >= HalVersion::V1_2 && allowUnspecifiedOutput,
829                     "output") &&
830             validatePools(request.pools, version));
831 }
832 
833 template bool validateRequest<V1_0::Request, V1_0::Model>(const V1_0::Request& request,
834                                                           const V1_0::Model& model,
835                                                           bool allowUnspecifiedOutput);
836 template bool validateRequest<V1_0::Request, V1_1::Model>(const V1_0::Request& request,
837                                                           const V1_1::Model& model,
838                                                           bool allowUnspecifiedOutput);
839 template bool validateRequest<V1_0::Request, V1_2::Model>(const V1_0::Request& request,
840                                                           const V1_2::Model& model,
841                                                           bool allowUnspecifiedOutput);
842 
843 template <>
validateRequest(const V1_3::Request & request,const V1_3::Model & model,bool allowUnspecifiedOutput)844 bool validateRequest(const V1_3::Request& request, const V1_3::Model& model,
845                      bool allowUnspecifiedOutput) {
846     return (validateRequestArguments(request.inputs, model.main.inputIndexes, model.main.operands,
847                                      request.pools,
848                                      /*allowUnspecified=*/false, "input") &&
849             validateRequestArguments(request.outputs, model.main.outputIndexes, model.main.operands,
850                                      request.pools, allowUnspecifiedOutput, "output") &&
851             validatePools(request.pools, HalVersion::V1_3));
852 }
853 
validateMemoryDesc(const V1_3::BufferDesc & desc,const hidl_vec<sp<V1_3::IPreparedModel>> & preparedModels,const hidl_vec<V1_3::BufferRole> & inputRoles,const hidl_vec<V1_3::BufferRole> & outputRoles,std::function<const V1_3::Model * (const sp<V1_3::IPreparedModel> &)> getModel,std::set<PreparedModelRole> * preparedModelRoles,V1_3::Operand * combinedOperand)854 bool validateMemoryDesc(const V1_3::BufferDesc& desc,
855                         const hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
856                         const hidl_vec<V1_3::BufferRole>& inputRoles,
857                         const hidl_vec<V1_3::BufferRole>& outputRoles,
858                         std::function<const V1_3::Model*(const sp<V1_3::IPreparedModel>&)> getModel,
859                         std::set<PreparedModelRole>* preparedModelRoles,
860                         V1_3::Operand* combinedOperand) {
861     NN_RET_CHECK(preparedModels.size() != 0);
862     NN_RET_CHECK(inputRoles.size() != 0 || outputRoles.size() != 0);
863 
864     std::set<PreparedModelRole> roles;
865     std::vector<V1_3::Operand> operands;
866     operands.reserve(inputRoles.size() + outputRoles.size());
867     for (const auto& role : inputRoles) {
868         NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
869         const auto& preparedModel = preparedModels[role.modelIndex];
870         NN_RET_CHECK(preparedModel != nullptr);
871         const auto* model = getModel(preparedModel);
872         NN_RET_CHECK(model != nullptr);
873         const auto& inputIndexes = model->main.inputIndexes;
874         NN_RET_CHECK_LT(role.ioIndex, inputIndexes.size());
875         NN_RET_CHECK_GT(role.frequency, 0.0f);
876         NN_RET_CHECK_LE(role.frequency, 1.0f);
877         const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
878         NN_RET_CHECK(success);
879         operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
880     }
881     for (const auto& role : outputRoles) {
882         NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
883         const auto& preparedModel = preparedModels[role.modelIndex];
884         NN_RET_CHECK(preparedModel != nullptr);
885         const auto* model = getModel(preparedModel);
886         NN_RET_CHECK(model != nullptr);
887         const auto& outputIndexes = model->main.outputIndexes;
888         NN_RET_CHECK_LT(role.ioIndex, outputIndexes.size());
889         NN_RET_CHECK_GT(role.frequency, 0.0f);
890         NN_RET_CHECK_LE(role.frequency, 1.0f);
891         const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
892         NN_RET_CHECK(success);
893         operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
894     }
895 
896     CHECK(!operands.empty());
897     const auto opType = operands[0].type;
898     const bool isExtension = isExtensionOperandType(opType);
899 
900     std::vector<uint32_t> dimensions = desc.dimensions;
901     for (const auto& operand : operands) {
902         NN_RET_CHECK(operand.type == operands[0].type)
903                 << toString(operand.type) << " vs " << toString(operands[0].type);
904         NN_RET_CHECK_EQ(operand.scale, operands[0].scale);
905         NN_RET_CHECK_EQ(operand.zeroPoint, operands[0].zeroPoint);
906         // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
907         if (!isExtension) {
908             NN_RET_CHECK(operand.extraParams == operands[0].extraParams)
909                     << toString(operand.extraParams) << " vs " << toString(operands[0].extraParams);
910         }
911         const auto combined = combineDimensions(dimensions, operand.dimensions);
912         NN_RET_CHECK(combined.has_value());
913         dimensions = combined.value();
914     }
915 
916     // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
917     if (!isExtension) {
918         NN_RET_CHECK(!nonExtensionOperandTypeIsScalar(static_cast<int>(opType)) ||
919                      dimensions.empty())
920                 << "invalid dimensions with scalar operand type.";
921     }
922 
923     if (preparedModelRoles != nullptr) {
924         *preparedModelRoles = std::move(roles);
925     }
926     if (combinedOperand != nullptr) {
927         *combinedOperand = operands[0];
928         combinedOperand->dimensions = dimensions;
929     }
930     return true;
931 }
932 
validateExecutionPreference(ExecutionPreference preference)933 bool validateExecutionPreference(ExecutionPreference preference) {
934     return preference == ExecutionPreference::LOW_POWER ||
935            preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
936            preference == ExecutionPreference::SUSTAINED_SPEED;
937 }
938 
validatePriority(Priority priority)939 bool validatePriority(Priority priority) {
940     return priority == Priority::LOW || priority == Priority::MEDIUM || priority == Priority::HIGH;
941 }
942 
validOperandType(V1_0::OperandType operandType)943 bool validOperandType(V1_0::OperandType operandType) {
944     switch (operandType) {
945         case V1_0::OperandType::FLOAT32:
946         case V1_0::OperandType::INT32:
947         case V1_0::OperandType::UINT32:
948         case V1_0::OperandType::TENSOR_FLOAT32:
949         case V1_0::OperandType::TENSOR_INT32:
950         case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
951         case V1_0::OperandType::OEM:
952         case V1_0::OperandType::TENSOR_OEM_BYTE:
953             return true;
954         default:
955             return false;
956     }
957 }
958 
validOperandType(V1_2::OperandType operandType)959 bool validOperandType(V1_2::OperandType operandType) {
960     switch (operandType) {
961         case V1_2::OperandType::FLOAT16:
962         case V1_2::OperandType::FLOAT32:
963         case V1_2::OperandType::INT32:
964         case V1_2::OperandType::UINT32:
965         case V1_2::OperandType::BOOL:
966         case V1_2::OperandType::TENSOR_FLOAT16:
967         case V1_2::OperandType::TENSOR_FLOAT32:
968         case V1_2::OperandType::TENSOR_INT32:
969         case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
970         case V1_2::OperandType::TENSOR_QUANT8_SYMM:
971         case V1_2::OperandType::TENSOR_QUANT16_ASYMM:
972         case V1_2::OperandType::TENSOR_QUANT16_SYMM:
973         case V1_2::OperandType::TENSOR_BOOL8:
974         case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
975         case V1_2::OperandType::OEM:
976         case V1_2::OperandType::TENSOR_OEM_BYTE:
977             return true;
978         default:
979             return isExtensionOperandType(static_cast<V1_3::OperandType>(operandType));
980     }
981 }
982 
validOperandType(V1_3::OperandType operandType)983 bool validOperandType(V1_3::OperandType operandType) {
984     switch (operandType) {
985         case V1_3::OperandType::FLOAT16:
986         case V1_3::OperandType::FLOAT32:
987         case V1_3::OperandType::INT32:
988         case V1_3::OperandType::UINT32:
989         case V1_3::OperandType::BOOL:
990         case V1_3::OperandType::TENSOR_FLOAT16:
991         case V1_3::OperandType::TENSOR_FLOAT32:
992         case V1_3::OperandType::TENSOR_INT32:
993         case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
994         case V1_3::OperandType::TENSOR_QUANT8_SYMM:
995         case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
996         case V1_3::OperandType::TENSOR_QUANT16_SYMM:
997         case V1_3::OperandType::TENSOR_BOOL8:
998         case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
999         case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
1000         case V1_3::OperandType::SUBGRAPH:
1001         case V1_3::OperandType::OEM:
1002         case V1_3::OperandType::TENSOR_OEM_BYTE:
1003             return true;
1004         default:
1005             return isExtensionOperandType(operandType);
1006     }
1007 }
1008 
1009 }  // namespace nn
1010 }  // namespace android
1011