/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "TestHarness.h"

#include <android-base/logging.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <map>
#include <numeric>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <vector>

namespace test_helper {

namespace {

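// True for float, double, and _Float16. A dedicated trait is needed because
// std::is_floating_point_v is not specified to cover the _Float16 extension type.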
template <typename T>
constexpr bool nnIsFloat = std::is_floating_point_v<T> || std::is_same_v<T, _Float16>;

constexpr uint32_t kMaxNumberOfPrintedErrors = 10;

// TODO(b/139442217): Allow passing accuracy criteria from spec.
// Currently we only need relaxed accuracy criteria on mobilenet tests, so we return the quant8
// tolerance simply based on the current test name.
int getQuant8AllowedError() {
    const ::testing::TestInfo* const testInfo =
            ::testing::UnitTest::GetInstance()->current_test_info();
    const std::string testCaseName = testInfo->test_case_name();
    const std::string testName = testInfo->name();
    // We relax the quant8 precision for all tests with mobilenet:
    // - CTS/VTS GeneratedTest and DynamicOutputShapeTest with mobilenet
    // - VTS CompilationCachingTest and CompilationCachingSecurityTest except for TOCTOU tests
    if (testName.find("mobilenet") != std::string::npos ||
        (testCaseName.find("CompilationCaching") != std::string::npos &&
         testName.find("TOCTOU") == std::string::npos)) {
        return 3;
    } else {
        return 1;
    }
}

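// Returns the total element count of |op|, i.e. the product of all dimensions.
// For example, dimensions {2, 3, 4} yield 24; an empty dimension list (a scalar)
// yields the multiplicative identity, 1.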
uint32_t getNumberOfElements(const TestOperand& op) {
    return std::reduce(op.dimensions.begin(), op.dimensions.end(), 1u, std::multiplies<uint32_t>());
}

// Check if the actual results meet the accuracy criterion.
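// An element passes if |actual - expected| <= atol + rtol * |expected|. In addition, the
// overall bias and mean squared error (both computed relative to |expected| for floating
// point types) must stay within criterion.bias and criterion.mse when enough elements
// are compared.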
template <typename T>
void expectNear(const TestOperand& op, const TestBuffer& result, const AccuracyCriterion& criterion,
                bool allowInvalid = false) {
    constexpr uint32_t kMinNumberOfElementsToTestBiasMSE = 10;
    const T* actualBuffer = result.get<T>();
    const T* expectedBuffer = op.data.get<T>();
    uint32_t len = getNumberOfElements(op), numErrors = 0, numSkip = 0;
    double bias = 0.0, mse = 0.0;
    for (uint32_t i = 0; i < len; i++) {
        // Compare all data types in double for precision and signed arithmetic.
        double actual = static_cast<double>(actualBuffer[i]);
        double expected = static_cast<double>(expectedBuffer[i]);
        double tolerableRange = criterion.atol + criterion.rtol * std::fabs(expected);
        EXPECT_FALSE(std::isnan(expected));

        // Skip invalid floating point values.
        if (allowInvalid &&
            (std::isinf(expected) || (std::is_same_v<T, float> && std::fabs(expected) > 1e3) ||
             (std::is_same_v<T, _Float16> && std::fabs(expected) > 1e2))) {
            numSkip++;
            continue;
        }

        // Accumulate bias and MSE. Use relative bias and MSE for floating point values.
        double diff = actual - expected;
        if constexpr (nnIsFloat<T>) {
            diff /= std::max(1.0, std::abs(expected));
        }
        bias += diff;
        mse += diff * diff;

        // Print at most kMaxNumberOfPrintedErrors errors via EXPECT_NEAR.
        if (numErrors < kMaxNumberOfPrintedErrors) {
            EXPECT_NEAR(expected, actual, tolerableRange) << "When comparing element " << i;
        }
        if (std::fabs(actual - expected) > tolerableRange) numErrors++;
    }
    EXPECT_EQ(numErrors, 0u);

    // Test bias and MSE.
    if (len < numSkip + kMinNumberOfElementsToTestBiasMSE) return;
    bias /= static_cast<double>(len - numSkip);
    mse /= static_cast<double>(len - numSkip);
    EXPECT_LE(std::fabs(bias), criterion.bias);
    EXPECT_LE(mse, criterion.mse);
}
// For boolean values, we expect the number of mismatches not to exceed a certain ratio.
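// For example, with allowedErrorRatio = 0.01 and 1000 elements, up to
// ceil(0.01 * 1000) = 10 mismatched elements are tolerated.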
void expectBooleanNearlyEqual(const TestOperand& op, const TestBuffer& result,
                              float allowedErrorRatio) {
    const bool8* actualBuffer = result.get<bool8>();
    const bool8* expectedBuffer = op.data.get<bool8>();
    uint32_t len = getNumberOfElements(op), numErrors = 0;
    std::stringstream errorMsg;
    for (uint32_t i = 0; i < len; i++) {
        if (expectedBuffer[i] != actualBuffer[i]) {
            if (numErrors < kMaxNumberOfPrintedErrors)
                errorMsg << "    Expected: " << expectedBuffer[i] << ", actual: " << actualBuffer[i]
                         << ", when comparing element " << i << "\n";
            numErrors++;
        }
    }
    // When |len| is small, allowedErrorCount is intentionally rounded up to 1, which allows
    // for greater tolerance.
    uint32_t allowedErrorCount = static_cast<uint32_t>(std::ceil(allowedErrorRatio * len));
    EXPECT_LE(numErrors, allowedErrorCount) << errorMsg.str();
}

// Calculates the expected probability from the unnormalized log-probability of
// each class in the input and compares it to the actual occurrence of that class
// in the output.
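// The expected probability is the softmax of the input logits:
// p(class i) = exp(logit_i) / sum_j exp(logit_j).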
void expectMultinomialDistributionWithinTolerance(const TestModel& model,
                                                  const std::vector<TestBuffer>& buffers) {
    // This function is only for the RANDOM_MULTINOMIAL single-operation test.
    CHECK_EQ(model.referenced.size(), 0u) << "Subgraphs not supported";
    ASSERT_EQ(model.main.operations.size(), 1u);
    ASSERT_EQ(model.main.operations[0].type, TestOperationType::RANDOM_MULTINOMIAL);
    ASSERT_EQ(model.main.inputIndexes.size(), 1u);
    ASSERT_EQ(model.main.outputIndexes.size(), 1u);
    ASSERT_EQ(buffers.size(), 1u);

    const auto& inputOperand = model.main.operands[model.main.inputIndexes[0]];
    const auto& outputOperand = model.main.operands[model.main.outputIndexes[0]];
    ASSERT_EQ(inputOperand.dimensions.size(), 2u);
    ASSERT_EQ(outputOperand.dimensions.size(), 2u);

    const int kBatchSize = inputOperand.dimensions[0];
    const int kNumClasses = inputOperand.dimensions[1];
    const int kNumSamples = outputOperand.dimensions[1];

    const uint32_t outputLength = getNumberOfElements(outputOperand);
    const int32_t* outputData = buffers[0].get<int32_t>();
    std::vector<int> classCounts(kNumClasses);
    for (uint32_t i = 0; i < outputLength; i++) {
        classCounts[outputData[i]]++;
    }

    const uint32_t inputLength = getNumberOfElements(inputOperand);
    std::vector<float> inputData(inputLength);
    if (inputOperand.type == TestOperandType::TENSOR_FLOAT32) {
        const float* inputRaw = inputOperand.data.get<float>();
        std::copy(inputRaw, inputRaw + inputLength, inputData.begin());
    } else if (inputOperand.type == TestOperandType::TENSOR_FLOAT16) {
        const _Float16* inputRaw = inputOperand.data.get<_Float16>();
        std::transform(inputRaw, inputRaw + inputLength, inputData.begin(),
                       [](_Float16 fp16) { return static_cast<float>(fp16); });
    } else {
        FAIL() << "Unknown input operand type for RANDOM_MULTINOMIAL.";
    }

    for (int b = 0; b < kBatchSize; ++b) {
        float probabilitySum = 0;
        // The input is a row-major [batchSize, numClasses] tensor.
        const int batchIndex = kNumClasses * b;
        for (int i = 0; i < kNumClasses; ++i) {
            probabilitySum += expf(inputData[batchIndex + i]);
        }
        for (int i = 0; i < kNumClasses; ++i) {
            float probability =
                    static_cast<float>(classCounts[i]) / static_cast<float>(kNumSamples);
            float probabilityExpected = expf(inputData[batchIndex + i]) / probabilitySum;
            EXPECT_THAT(probability,
                        ::testing::FloatNear(probabilityExpected,
                                             model.expectedMultinomialDistributionTolerance));
        }
    }
}

}  // namespace

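// Checks each (non-ignored) output buffer in |buffers| against the golden data stored in
// |model|, dispatching on the operand type and applying the matching entry of |criteria|.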
void checkResults(const TestModel& model, const std::vector<TestBuffer>& buffers,
                  const AccuracyCriteria& criteria) {
    ASSERT_EQ(model.main.outputIndexes.size(), buffers.size());
    for (uint32_t i = 0; i < model.main.outputIndexes.size(); i++) {
        const uint32_t outputIndex = model.main.outputIndexes[i];
        SCOPED_TRACE(testing::Message()
                     << "When comparing output " << i << " (op" << outputIndex << ")");
        const auto& operand = model.main.operands[outputIndex];
        const auto& result = buffers[i];
        if (operand.isIgnored) continue;

        switch (operand.type) {
            case TestOperandType::TENSOR_FLOAT32:
                expectNear<float>(operand, result, criteria.float32, criteria.allowInvalidFpValues);
                break;
            case TestOperandType::TENSOR_FLOAT16:
                expectNear<_Float16>(operand, result, criteria.float16,
                                     criteria.allowInvalidFpValues);
                break;
            case TestOperandType::TENSOR_INT32:
            case TestOperandType::INT32:
                expectNear<int32_t>(operand, result, criteria.int32);
                break;
            case TestOperandType::TENSOR_QUANT8_ASYMM:
                expectNear<uint8_t>(operand, result, criteria.quant8Asymm);
                break;
            case TestOperandType::TENSOR_QUANT8_SYMM:
                expectNear<int8_t>(operand, result, criteria.quant8Symm);
                break;
            case TestOperandType::TENSOR_QUANT16_ASYMM:
                expectNear<uint16_t>(operand, result, criteria.quant16Asymm);
                break;
            case TestOperandType::TENSOR_QUANT16_SYMM:
                expectNear<int16_t>(operand, result, criteria.quant16Symm);
                break;
            case TestOperandType::TENSOR_BOOL8:
                expectBooleanNearlyEqual(operand, result, criteria.bool8AllowedErrorRatio);
                break;
            case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
                expectNear<int8_t>(operand, result, criteria.quant8AsymmSigned);
                break;
            default:
                FAIL() << "Data type not supported.";
        }
    }
}

void checkResults(const TestModel& model, const std::vector<TestBuffer>& buffers) {
    // For the RANDOM_MULTINOMIAL test only.
    if (model.expectedMultinomialDistributionTolerance > 0.0f) {
        expectMultinomialDistributionWithinTolerance(model, buffers);
        return;
    }

    // Decide the default tolerable range.
    //
    // For floating-point models, we use the relaxed precision if either
    // - the relaxed computation flag is set, or
    // - the model has at least one TENSOR_FLOAT16 operand.
    //
    // The bias and MSE criteria are implicitly set to the maximum -- we do not enforce these
    // criteria in normal generated tests.
    //
    // TODO: Adjust the error limit based on testing.
    //
    AccuracyCriteria criteria = {
            // The relative tolerance is 5 ULP of FP32 (5 * 2^-23).
            .float32 = {.atol = 1e-5, .rtol = 5.0f * 1.1920928955078125e-7},
            // Both the absolute and relative tolerance are 5 ULP of FP16 (5 * 2^-10).
            .float16 = {.atol = 5.0f * 0.0009765625, .rtol = 5.0f * 0.0009765625},
            .int32 = {.atol = 1},
            .quant8Asymm = {.atol = 1},
            .quant8Symm = {.atol = 1},
            .quant16Asymm = {.atol = 1},
            .quant16Symm = {.atol = 1},
            .bool8AllowedErrorRatio = 0.0f,
            // Since generated tests are hand-calculated, there should be no invalid FP values.
            .allowInvalidFpValues = false,
    };
    bool hasFloat16Inputs = false;
    model.forEachSubgraph([&hasFloat16Inputs](const TestSubgraph& subgraph) {
        if (!hasFloat16Inputs) {
            hasFloat16Inputs = std::any_of(subgraph.operands.begin(), subgraph.operands.end(),
                                           [](const TestOperand& op) {
                                               return op.type == TestOperandType::TENSOR_FLOAT16;
                                           });
        }
    });
    if (model.isRelaxed || hasFloat16Inputs) {
        criteria.float32 = criteria.float16;
    }
    const double quant8AllowedError = getQuant8AllowedError();
    criteria.quant8Asymm.atol = quant8AllowedError;
    criteria.quant8AsymmSigned.atol = quant8AllowedError;
    criteria.quant8Symm.atol = quant8AllowedError;

    checkResults(model, buffers, criteria);
}

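// Returns a copy of |testModel| in which every TENSOR_QUANT8_ASYMM operand is rewritten as
// TENSOR_QUANT8_ASYMM_SIGNED: the zero point is shifted down by 128 and each data byte is
// remapped accordingly, e.g. a zeroPoint of 128 becomes 0 and a stored value of 200 becomes 72.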
TestModel convertQuant8AsymmOperandsToSigned(const TestModel& testModel) {
    auto processSubgraph = [](TestSubgraph* subgraph) {
        for (TestOperand& operand : subgraph->operands) {
            if (operand.type == test_helper::TestOperandType::TENSOR_QUANT8_ASYMM) {
                operand.type = test_helper::TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED;
                operand.zeroPoint -= 128;
                const uint8_t* inputOperandData = operand.data.get<uint8_t>();
                int8_t* outputOperandData = operand.data.getMutable<int8_t>();
                for (size_t i = 0; i < operand.data.size(); ++i) {
                    outputOperandData[i] =
                            static_cast<int8_t>(static_cast<int32_t>(inputOperandData[i]) - 128);
                }
            }
        }
    };
    TestModel converted(testModel.copy());
    processSubgraph(&converted.main);
    for (TestSubgraph& subgraph : converted.referenced) {
        processSubgraph(&subgraph);
    }
    return converted;
}

bool isQuantizedType(TestOperandType type) {
    static const std::set<TestOperandType> kQuantizedTypes = {
            TestOperandType::TENSOR_QUANT8_ASYMM,
            TestOperandType::TENSOR_QUANT8_SYMM,
            TestOperandType::TENSOR_QUANT16_ASYMM,
            TestOperandType::TENSOR_QUANT16_SYMM,
            TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL,
            TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED,
    };
    return kQuantizedTypes.count(type) > 0;
}

bool isFloatType(TestOperandType type) {
    static const std::set<TestOperandType> kFloatTypes = {
            TestOperandType::TENSOR_FLOAT32,
            TestOperandType::TENSOR_FLOAT16,
            TestOperandType::FLOAT32,
            TestOperandType::FLOAT16,
    };
    return kFloatTypes.count(type) > 0;
}

bool isConstant(TestOperandLifeTime lifetime) {
    return lifetime == TestOperandLifeTime::CONSTANT_COPY ||
           lifetime == TestOperandLifeTime::CONSTANT_REFERENCE;
}

namespace {

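// The tables below are indexed by the integer value of the corresponding enum, so the entry
// order must match the declaration order of TestOperationType and TestOperandType.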
const char* kOperationTypeNames[] = {
        "ADD",
        "AVERAGE_POOL_2D",
        "CONCATENATION",
        "CONV_2D",
        "DEPTHWISE_CONV_2D",
        "DEPTH_TO_SPACE",
        "DEQUANTIZE",
        "EMBEDDING_LOOKUP",
        "FLOOR",
        "FULLY_CONNECTED",
        "HASHTABLE_LOOKUP",
        "L2_NORMALIZATION",
        "L2_POOL",
        "LOCAL_RESPONSE_NORMALIZATION",
        "LOGISTIC",
        "LSH_PROJECTION",
        "LSTM",
        "MAX_POOL_2D",
        "MUL",
        "RELU",
        "RELU1",
        "RELU6",
        "RESHAPE",
        "RESIZE_BILINEAR",
        "RNN",
        "SOFTMAX",
        "SPACE_TO_DEPTH",
        "SVDF",
        "TANH",
        "BATCH_TO_SPACE_ND",
        "DIV",
        "MEAN",
        "PAD",
        "SPACE_TO_BATCH_ND",
        "SQUEEZE",
        "STRIDED_SLICE",
        "SUB",
        "TRANSPOSE",
        "ABS",
        "ARGMAX",
        "ARGMIN",
        "AXIS_ALIGNED_BBOX_TRANSFORM",
        "BIDIRECTIONAL_SEQUENCE_LSTM",
        "BIDIRECTIONAL_SEQUENCE_RNN",
        "BOX_WITH_NMS_LIMIT",
        "CAST",
        "CHANNEL_SHUFFLE",
        "DETECTION_POSTPROCESSING",
        "EQUAL",
        "EXP",
        "EXPAND_DIMS",
        "GATHER",
        "GENERATE_PROPOSALS",
        "GREATER",
        "GREATER_EQUAL",
        "GROUPED_CONV_2D",
        "HEATMAP_MAX_KEYPOINT",
        "INSTANCE_NORMALIZATION",
        "LESS",
        "LESS_EQUAL",
        "LOG",
        "LOGICAL_AND",
        "LOGICAL_NOT",
        "LOGICAL_OR",
        "LOG_SOFTMAX",
        "MAXIMUM",
        "MINIMUM",
        "NEG",
        "NOT_EQUAL",
        "PAD_V2",
        "POW",
        "PRELU",
        "QUANTIZE",
        "QUANTIZED_16BIT_LSTM",
        "RANDOM_MULTINOMIAL",
        "REDUCE_ALL",
        "REDUCE_ANY",
        "REDUCE_MAX",
        "REDUCE_MIN",
        "REDUCE_PROD",
        "REDUCE_SUM",
        "ROI_ALIGN",
        "ROI_POOLING",
        "RSQRT",
        "SELECT",
        "SIN",
        "SLICE",
        "SPLIT",
        "SQRT",
        "TILE",
        "TOPK_V2",
        "TRANSPOSE_CONV_2D",
        "UNIDIRECTIONAL_SEQUENCE_LSTM",
        "UNIDIRECTIONAL_SEQUENCE_RNN",
        "RESIZE_NEAREST_NEIGHBOR",
        "QUANTIZED_LSTM",
        "IF",
        "WHILE",
        "ELU",
        "HARD_SWISH",
        "FILL",
        "RANK",
};

const char* kOperandTypeNames[] = {
        "FLOAT32",
        "INT32",
        "UINT32",
        "TENSOR_FLOAT32",
        "TENSOR_INT32",
        "TENSOR_QUANT8_ASYMM",
        "BOOL",
        "TENSOR_QUANT16_SYMM",
        "TENSOR_FLOAT16",
        "TENSOR_BOOL8",
        "FLOAT16",
        "TENSOR_QUANT8_SYMM_PER_CHANNEL",
        "TENSOR_QUANT16_ASYMM",
        "TENSOR_QUANT8_SYMM",
        "TENSOR_QUANT8_ASYMM_SIGNED",
};

bool isScalarType(TestOperandType type) {
    static const std::vector<bool> kIsScalarOperandType = {
            true,   // TestOperandType::FLOAT32
            true,   // TestOperandType::INT32
            true,   // TestOperandType::UINT32
            false,  // TestOperandType::TENSOR_FLOAT32
            false,  // TestOperandType::TENSOR_INT32
            false,  // TestOperandType::TENSOR_QUANT8_ASYMM
            true,   // TestOperandType::BOOL
            false,  // TestOperandType::TENSOR_QUANT16_SYMM
            false,  // TestOperandType::TENSOR_FLOAT16
            false,  // TestOperandType::TENSOR_BOOL8
            true,   // TestOperandType::FLOAT16
            false,  // TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL
            false,  // TestOperandType::TENSOR_QUANT16_ASYMM
            false,  // TestOperandType::TENSOR_QUANT8_SYMM
            false,  // TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED
    };
    return kIsScalarOperandType[static_cast<int>(type)];
}

std::string getOperandClassInSpecFile(TestOperandLifeTime lifetime) {
    switch (lifetime) {
        case TestOperandLifeTime::SUBGRAPH_INPUT:
            return "Input";
        case TestOperandLifeTime::SUBGRAPH_OUTPUT:
            return "Output";
        case TestOperandLifeTime::CONSTANT_COPY:
        case TestOperandLifeTime::CONSTANT_REFERENCE:
        case TestOperandLifeTime::NO_VALUE:
            return "Parameter";
        case TestOperandLifeTime::TEMPORARY_VARIABLE:
            return "Internal";
        default:
            CHECK(false);
            return "";
    }
}

template <typename T>
std::string defaultToStringFunc(const T& value) {
    return std::to_string(value);
}
template <>
std::string defaultToStringFunc<_Float16>(const _Float16& value) {
    return defaultToStringFunc(static_cast<float>(value));
}

// Dump floating point values in hex representation.
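// For example, 3.0f is typically dumped as "0x1.8p+1" (the exact formatting of
// std::hexfloat may vary by C library).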
template <typename T>
std::string toHexFloatString(const T& value);
template <>
std::string toHexFloatString<float>(const float& value) {
    std::stringstream ss;
    ss << "\"" << std::hexfloat << value << "\"";
    return ss.str();
}
template <>
std::string toHexFloatString<_Float16>(const _Float16& value) {
    return toHexFloatString(static_cast<float>(value));
}

template <typename Iterator, class ToStringFunc>
std::string join(const std::string& joint, Iterator begin, Iterator end, ToStringFunc func) {
    std::stringstream ss;
    for (auto it = begin; it < end; it++) {
        ss << (it == begin ? "" : joint) << func(*it);
    }
    return ss.str();
}

template <typename T, class ToStringFunc>
std::string join(const std::string& joint, const std::vector<T>& range, ToStringFunc func) {
    return join(joint, range.begin(), range.end(), func);
}

template <typename T>
void dumpTestBufferToSpecFileHelper(const TestBuffer& buffer, bool useHexFloat, std::ostream& os) {
    const T* data = buffer.get<T>();
    const uint32_t length = buffer.size() / sizeof(T);
    if constexpr (nnIsFloat<T>) {
        if (useHexFloat) {
            os << "from_hex([" << join(", ", data, data + length, toHexFloatString<T>) << "])";
            return;
        }
    }
    os << "[" << join(", ", data, data + length, defaultToStringFunc<T>) << "]";
}

}  // namespace

const char* toString(TestOperandType type) {
    return kOperandTypeNames[static_cast<int>(type)];
}

const char* toString(TestOperationType type) {
    return kOperationTypeNames[static_cast<int>(type)];
}

// Dump a test buffer.
void SpecDumper::dumpTestBuffer(TestOperandType type, const TestBuffer& buffer, bool useHexFloat) {
    switch (type) {
        case TestOperandType::FLOAT32:
        case TestOperandType::TENSOR_FLOAT32:
            dumpTestBufferToSpecFileHelper<float>(buffer, useHexFloat, mOs);
            break;
        case TestOperandType::INT32:
        case TestOperandType::TENSOR_INT32:
            dumpTestBufferToSpecFileHelper<int32_t>(buffer, useHexFloat, mOs);
            break;
        case TestOperandType::TENSOR_QUANT8_ASYMM:
            dumpTestBufferToSpecFileHelper<uint8_t>(buffer, useHexFloat, mOs);
            break;
        case TestOperandType::TENSOR_QUANT8_SYMM:
        case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            dumpTestBufferToSpecFileHelper<int8_t>(buffer, useHexFloat, mOs);
            break;
        case TestOperandType::TENSOR_QUANT16_ASYMM:
            dumpTestBufferToSpecFileHelper<uint16_t>(buffer, useHexFloat, mOs);
            break;
        case TestOperandType::TENSOR_QUANT16_SYMM:
            dumpTestBufferToSpecFileHelper<int16_t>(buffer, useHexFloat, mOs);
            break;
        case TestOperandType::BOOL:
        case TestOperandType::TENSOR_BOOL8:
            dumpTestBufferToSpecFileHelper<bool8>(buffer, useHexFloat, mOs);
            break;
        case TestOperandType::FLOAT16:
        case TestOperandType::TENSOR_FLOAT16:
            dumpTestBufferToSpecFileHelper<_Float16>(buffer, useHexFloat, mOs);
            break;
        default:
            CHECK(false) << "Unknown type when dumping the buffer";
    }
}

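// Emits one spec-file line per operand. For example, a float input operand of shape [1, 2]
// dumps roughly as:
//   op0 = Input("op0", ["TENSOR_FLOAT32", [1, 2]])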
void SpecDumper::dumpTestOperand(const TestOperand& operand, uint32_t index) {
    mOs << "op" << index << " = " << getOperandClassInSpecFile(operand.lifetime) << "(\"op" << index
        << "\", [\"" << toString(operand.type) << "\", ["
        << join(", ", operand.dimensions, defaultToStringFunc<uint32_t>) << "]";
    if (operand.scale != 0.0f || operand.zeroPoint != 0) {
        mOs << ", float.fromhex(" << toHexFloatString(operand.scale) << "), " << operand.zeroPoint;
    }
    mOs << "]";
    if (operand.lifetime == TestOperandLifeTime::CONSTANT_COPY ||
        operand.lifetime == TestOperandLifeTime::CONSTANT_REFERENCE) {
        mOs << ", ";
        dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/true);
    } else if (operand.lifetime == TestOperandLifeTime::NO_VALUE) {
        mOs << ", value=None";
    }
    mOs << ")";
    // For quantized data types, append a human-readable scale at the end.
    if (operand.scale != 0.0f) {
        mOs << "  # scale = " << operand.scale;
    }
    // For float buffers, append human-readable values at the end.
    if (isFloatType(operand.type) &&
        (operand.lifetime == TestOperandLifeTime::CONSTANT_COPY ||
         operand.lifetime == TestOperandLifeTime::CONSTANT_REFERENCE)) {
        mOs << "  # ";
        dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/false);
    }
    mOs << "\n";
}

void SpecDumper::dumpTestOperation(const TestOperation& operation) {
    auto toOperandName = [](uint32_t index) { return "op" + std::to_string(index); };
    mOs << "model = model.Operation(\"" << toString(operation.type) << "\", "
        << join(", ", operation.inputs, toOperandName) << ").To("
        << join(", ", operation.outputs, toOperandName) << ")\n";
}

void SpecDumper::dumpTestModel() {
    CHECK_EQ(kTestModel.referenced.size(), 0u) << "Subgraphs not supported";
    mOs << "from_hex = lambda l: [float.fromhex(i) for i in l]\n\n";

    // Dump model operands.
    mOs << "# Model operands\n";
    for (uint32_t i = 0; i < kTestModel.main.operands.size(); i++) {
        dumpTestOperand(kTestModel.main.operands[i], i);
    }

    // Dump model operations.
    mOs << "\n# Model operations\nmodel = Model()\n";
    for (const auto& operation : kTestModel.main.operations) {
        dumpTestOperation(operation);
    }

    // Dump input/output buffers.
    mOs << "\n# Example\nExample({\n";
    for (uint32_t i = 0; i < kTestModel.main.operands.size(); i++) {
        const auto& operand = kTestModel.main.operands[i];
        if (operand.lifetime != TestOperandLifeTime::SUBGRAPH_INPUT &&
            operand.lifetime != TestOperandLifeTime::SUBGRAPH_OUTPUT) {
            continue;
        }
        // For float buffers, dump human-readable values as a comment.
        if (isFloatType(operand.type)) {
            mOs << "    # op" << i << ": ";
            dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/false);
            mOs << "\n";
        }
        mOs << "    op" << i << ": ";
        dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/true);
        mOs << ",\n";
    }
    mOs << "}).DisableLifeTimeVariation()\n";
}

void SpecDumper::dumpResults(const std::string& name, const std::vector<TestBuffer>& results) {
    CHECK_EQ(results.size(), kTestModel.main.outputIndexes.size());
    mOs << "\n# Results from " << name << "\n{\n";
    for (uint32_t i = 0; i < results.size(); i++) {
        const uint32_t outputIndex = kTestModel.main.outputIndexes[i];
        const auto& operand = kTestModel.main.operands[outputIndex];
        // For float buffers, dump human-readable values as a comment.
        if (isFloatType(operand.type)) {
            mOs << "    # op" << outputIndex << ": ";
            dumpTestBuffer(operand.type, results[i], /*useHexFloat=*/false);
            mOs << "\n";
        }
        mOs << "    op" << outputIndex << ": ";
        dumpTestBuffer(operand.type, results[i], /*useHexFloat=*/true);
        mOs << ",\n";
    }
    mOs << "}\n";
}

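// Returns a float32 copy of |op|. Quantized operands (scale != 0) are dequantized with
// float = (storedValue - zeroPoint) * scale; all other types are converted by value.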
template <typename T>
static TestOperand convertOperandToFloat32(const TestOperand& op) {
    TestOperand converted = op;
    converted.type =
            isScalarType(op.type) ? TestOperandType::FLOAT32 : TestOperandType::TENSOR_FLOAT32;
    converted.scale = 0.0f;
    converted.zeroPoint = 0;

    const uint32_t numberOfElements = getNumberOfElements(converted);
    converted.data = TestBuffer(numberOfElements * sizeof(float));
    const T* data = op.data.get<T>();
    float* floatData = converted.data.getMutable<float>();

    if (op.scale != 0.0f) {
        std::transform(data, data + numberOfElements, floatData, [&op](T val) {
            return (static_cast<float>(val) - op.zeroPoint) * op.scale;
        });
    } else {
        std::transform(data, data + numberOfElements, floatData,
                       [](T val) { return static_cast<float>(val); });
    }
    return converted;
}

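// Returns a float32 version of |testModel| for computing reference outputs, or std::nullopt
// for multi-operation graphs and for operations whose float and quantized semantics differ.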
std::optional<TestModel> convertToFloat32Model(const TestModel& testModel) {
    // Only single-operation graphs are supported.
    if (testModel.referenced.size() > 0 || testModel.main.operations.size() > 1) {
        return std::nullopt;
    }

    // Check for unsupported operations.
    CHECK(!testModel.main.operations.empty());
    const auto& operation = testModel.main.operations[0];
    // Do not convert type-casting operations.
    if (operation.type == TestOperationType::DEQUANTIZE ||
        operation.type == TestOperationType::QUANTIZE ||
        operation.type == TestOperationType::CAST) {
        return std::nullopt;
    }
    // HASHTABLE_LOOKUP behaves differently for float and quant data types: float
    // HASHTABLE_LOOKUP outputs logical zero on a key miss, while quant
    // HASHTABLE_LOOKUP outputs byte zero.
    if (operation.type == TestOperationType::HASHTABLE_LOOKUP) {
        return std::nullopt;
    }

    auto convert = [&testModel, &operation](const TestOperand& op, uint32_t index) {
        switch (op.type) {
            case TestOperandType::TENSOR_FLOAT32:
            case TestOperandType::FLOAT32:
            case TestOperandType::TENSOR_BOOL8:
            case TestOperandType::BOOL:
            case TestOperandType::UINT32:
                return op;
            case TestOperandType::INT32:
                // The third input of PAD_V2 uses INT32 to specify the pad value.
                if (operation.type == TestOperationType::PAD_V2 && index == operation.inputs[2]) {
                    // The scale and zero point are inherited from the first input.
                    const uint32_t input0Index = operation.inputs[0];
                    const auto& input0 = testModel.main.operands[input0Index];
                    TestOperand scalarWithScaleAndZeroPoint = op;
                    scalarWithScaleAndZeroPoint.scale = input0.scale;
                    scalarWithScaleAndZeroPoint.zeroPoint = input0.zeroPoint;
                    return convertOperandToFloat32<int32_t>(scalarWithScaleAndZeroPoint);
                }
                return op;
            case TestOperandType::TENSOR_INT32:
                if (op.scale != 0.0f || op.zeroPoint != 0) {
                    return convertOperandToFloat32<int32_t>(op);
                }
                return op;
            case TestOperandType::TENSOR_FLOAT16:
            case TestOperandType::FLOAT16:
                return convertOperandToFloat32<_Float16>(op);
            case TestOperandType::TENSOR_QUANT8_ASYMM:
                return convertOperandToFloat32<uint8_t>(op);
            case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
                return convertOperandToFloat32<int8_t>(op);
            case TestOperandType::TENSOR_QUANT16_ASYMM:
                return convertOperandToFloat32<uint16_t>(op);
            case TestOperandType::TENSOR_QUANT16_SYMM:
                return convertOperandToFloat32<int16_t>(op);
            default:
                CHECK(false) << "OperandType not supported";
                return TestOperand{};
        }
    };

    TestModel converted = testModel;
    for (uint32_t i = 0; i < testModel.main.operands.size(); i++) {
        converted.main.operands[i] = convert(testModel.main.operands[i], i);
    }
    return converted;
}

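// Quantizes the float32 reference data in |fpBuffer| back into |op|'s native type with
// q = clamp(round(value / scale) + zeroPoint, numeric_limits<T>::min(),
// numeric_limits<T>::max()); when scale is 0, values are converted by a plain static_cast.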
template <typename T>
static void setDataFromFloat32Buffer(const TestBuffer& fpBuffer, TestOperand* op) {
    const uint32_t numberOfElements = getNumberOfElements(*op);
    const float* floatData = fpBuffer.get<float>();
    T* data = op->data.getMutable<T>();

    if (op->scale != 0.0f) {
        std::transform(floatData, floatData + numberOfElements, data, [op](float val) {
            int32_t unclamped = std::round(val / op->scale) + op->zeroPoint;
            int32_t clamped = std::clamp<int32_t>(unclamped, std::numeric_limits<T>::min(),
                                                  std::numeric_limits<T>::max());
            return static_cast<T>(clamped);
        });
    } else {
        std::transform(floatData, floatData + numberOfElements, data,
                       [](float val) { return static_cast<T>(val); });
    }
}

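// Writes the float32 reference |results| back into |model|'s output operands as golden data,
// re-quantizing or narrowing where needed. Intended to pair with convertToFloat32Model().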
void setExpectedOutputsFromFloat32Results(const std::vector<TestBuffer>& results,
                                          TestModel* model) {
    CHECK_EQ(model->referenced.size(), 0u) << "Subgraphs not supported";
    CHECK_EQ(model->main.operations.size(), 1u) << "Only single-operation graphs are supported";

    for (uint32_t i = 0; i < results.size(); i++) {
        uint32_t outputIndex = model->main.outputIndexes[i];
        auto& op = model->main.operands[outputIndex];
        switch (op.type) {
            case TestOperandType::TENSOR_FLOAT32:
            case TestOperandType::FLOAT32:
            case TestOperandType::TENSOR_BOOL8:
            case TestOperandType::BOOL:
            case TestOperandType::INT32:
            case TestOperandType::UINT32:
                op.data = results[i];
                break;
            case TestOperandType::TENSOR_INT32:
                if (op.scale != 0.0f) {
                    setDataFromFloat32Buffer<int32_t>(results[i], &op);
                } else {
                    op.data = results[i];
                }
                break;
            case TestOperandType::TENSOR_FLOAT16:
            case TestOperandType::FLOAT16:
                setDataFromFloat32Buffer<_Float16>(results[i], &op);
                break;
            case TestOperandType::TENSOR_QUANT8_ASYMM:
                setDataFromFloat32Buffer<uint8_t>(results[i], &op);
                break;
            case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
                setDataFromFloat32Buffer<int8_t>(results[i], &op);
                break;
            case TestOperandType::TENSOR_QUANT16_ASYMM:
                setDataFromFloat32Buffer<uint16_t>(results[i], &op);
                break;
            case TestOperandType::TENSOR_QUANT16_SYMM:
                setDataFromFloat32Buffer<int16_t>(results[i], &op);
                break;
            default:
                CHECK(false) << "OperandType not supported";
        }
    }
}

}  // namespace test_helper