1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
17 #include "TestHarness.h"
19 #include <android-base/logging.h>
20 #include <gmock/gmock-matchers.h>
21 #include <gtest/gtest.h>
23 #include <algorithm>
24 #include <cmath>
25 #include <functional>
26 #include <limits>
27 #include <map>
28 #include <numeric>
29 #include <set>
30 #include <string>
31 #include <vector>
33 namespace test_helper {
35 namespace {
37 template <typename T>
38 constexpr bool nnIsFloat = std::is_floating_point_v<T> || std::is_same_v<T, _Float16>;
40 constexpr uint32_t kMaxNumberOfPrintedErrors = 10;
42 // TODO(b/139442217): Allow passing accuracy criteria from spec.
43 // Currently we only need relaxed accuracy criteria on mobilenet tests, so we return the quant8
44 // tolerance simply based on the current test name.
getQuant8AllowedError()45 int getQuant8AllowedError() {
46 const ::testing::TestInfo* const testInfo =
47 ::testing::UnitTest::GetInstance()->current_test_info();
48 const std::string testCaseName = testInfo->test_case_name();
49 const std::string testName = testInfo->name();
50 // We relax the quant8 precision for all tests with mobilenet:
51 // - CTS/VTS GeneratedTest and DynamicOutputShapeTest with mobilenet
52 // - VTS CompilationCachingTest and CompilationCachingSecurityTest except for TOCTOU tests
53 if (testName.find("mobilenet") != std::string::npos ||
54 (testCaseName.find("CompilationCaching") != std::string::npos &&
55 testName.find("TOCTOU") == std::string::npos)) {
56 return 3;
57 } else {
58 return 1;
59 }
60 }
getNumberOfElements(const TestOperand & op)62 uint32_t getNumberOfElements(const TestOperand& op) {
63 return std::reduce(op.dimensions.begin(), op.dimensions.end(), 1u, std::multiplies<uint32_t>());
64 }
66 // Check if the actual results meet the accuracy criterion.
67 template <typename T>
expectNear(const TestOperand & op,const TestBuffer & result,const AccuracyCriterion & criterion,bool allowInvalid=false)68 void expectNear(const TestOperand& op, const TestBuffer& result, const AccuracyCriterion& criterion,
69 bool allowInvalid = false) {
70 constexpr uint32_t kMinNumberOfElementsToTestBiasMSE = 10;
71 const T* actualBuffer = result.get<T>();
72 const T* expectedBuffer = op.data.get<T>();
73 uint32_t len = getNumberOfElements(op), numErrors = 0, numSkip = 0;
74 double bias = 0.0f, mse = 0.0f;
75 for (uint32_t i = 0; i < len; i++) {
76 // Compare all data types in double for precision and signed arithmetic.
77 double actual = static_cast<double>(actualBuffer[i]);
78 double expected = static_cast<double>(expectedBuffer[i]);
79 double tolerableRange = criterion.atol + criterion.rtol * std::fabs(expected);
80 EXPECT_FALSE(std::isnan(expected));
82 // Skip invalid floating point values.
83 if (allowInvalid &&
84 (std::isinf(expected) || (std::is_same_v<T, float> && std::fabs(expected) > 1e3) ||
85 (std::is_same_v<T, _Float16> || std::fabs(expected) > 1e2))) {
86 numSkip++;
87 continue;
88 }
90 // Accumulate bias and MSE. Use relative bias and MSE for floating point values.
91 double diff = actual - expected;
92 if constexpr (nnIsFloat<T>) {
93 diff /= std::max(1.0, std::abs(expected));
94 }
95 bias += diff;
96 mse += diff * diff;
98 // Print at most kMaxNumberOfPrintedErrors errors by EXPECT_NEAR.
99 if (numErrors < kMaxNumberOfPrintedErrors) {
100 EXPECT_NEAR(expected, actual, tolerableRange) << "When comparing element " << i;
101 }
102 if (std::fabs(actual - expected) > tolerableRange) numErrors++;
103 }
104 EXPECT_EQ(numErrors, 0u);
106 // Test bias and MSE.
107 if (len < numSkip + kMinNumberOfElementsToTestBiasMSE) return;
108 bias /= static_cast<double>(len - numSkip);
109 mse /= static_cast<double>(len - numSkip);
110 EXPECT_LE(std::fabs(bias), criterion.bias);
111 EXPECT_LE(mse, criterion.mse);
112 }
114 // For boolean values, we expect the number of mismatches does not exceed a certain ratio.
expectBooleanNearlyEqual(const TestOperand & op,const TestBuffer & result,float allowedErrorRatio)115 void expectBooleanNearlyEqual(const TestOperand& op, const TestBuffer& result,
116 float allowedErrorRatio) {
117 const bool8* actualBuffer = result.get<bool8>();
118 const bool8* expectedBuffer = op.data.get<bool8>();
119 uint32_t len = getNumberOfElements(op), numErrors = 0;
120 std::stringstream errorMsg;
121 for (uint32_t i = 0; i < len; i++) {
122 if (expectedBuffer[i] != actualBuffer[i]) {
123 if (numErrors < kMaxNumberOfPrintedErrors)
124 errorMsg << " Expected: " << expectedBuffer[i] << ", actual: " << actualBuffer[i]
125 << ", when comparing element " << i << "\n";
126 numErrors++;
127 }
128 }
129 // When |len| is small, the allowedErrorCount will intentionally ceil at 1, which allows for
130 // greater tolerance.
131 uint32_t allowedErrorCount = static_cast<uint32_t>(std::ceil(allowedErrorRatio * len));
132 EXPECT_LE(numErrors, allowedErrorCount) << errorMsg.str();
133 }
135 // Calculates the expected probability from the unnormalized log-probability of
136 // each class in the input and compares it to the actual occurrence of that class
137 // in the output.
expectMultinomialDistributionWithinTolerance(const TestModel & model,const std::vector<TestBuffer> & buffers)138 void expectMultinomialDistributionWithinTolerance(const TestModel& model,
139 const std::vector<TestBuffer>& buffers) {
140 // This function is only for RANDOM_MULTINOMIAL single-operation test.
141 CHECK_EQ(model.referenced.size(), 0u) << "Subgraphs not supported";
142 ASSERT_EQ(model.main.operations.size(), 1u);
143 ASSERT_EQ(model.main.operations[0].type, TestOperationType::RANDOM_MULTINOMIAL);
144 ASSERT_EQ(model.main.inputIndexes.size(), 1u);
145 ASSERT_EQ(model.main.outputIndexes.size(), 1u);
146 ASSERT_EQ(buffers.size(), 1u);
148 const auto& inputOperand = model.main.operands[model.main.inputIndexes[0]];
149 const auto& outputOperand = model.main.operands[model.main.outputIndexes[0]];
150 ASSERT_EQ(inputOperand.dimensions.size(), 2u);
151 ASSERT_EQ(outputOperand.dimensions.size(), 2u);
153 const int kBatchSize = inputOperand.dimensions[0];
154 const int kNumClasses = inputOperand.dimensions[1];
155 const int kNumSamples = outputOperand.dimensions[1];
157 const uint32_t outputLength = getNumberOfElements(outputOperand);
158 const int32_t* outputData = buffers[0].get<int32_t>();
159 std::vector<int> classCounts(kNumClasses);
160 for (uint32_t i = 0; i < outputLength; i++) {
161 classCounts[outputData[i]]++;
162 }
164 const uint32_t inputLength = getNumberOfElements(inputOperand);
165 std::vector<float> inputData(inputLength);
166 if (inputOperand.type == TestOperandType::TENSOR_FLOAT32) {
167 const float* inputRaw = inputOperand.data.get<float>();
168 std::copy(inputRaw, inputRaw + inputLength, inputData.begin());
169 } else if (inputOperand.type == TestOperandType::TENSOR_FLOAT16) {
170 const _Float16* inputRaw = inputOperand.data.get<_Float16>();
171 std::transform(inputRaw, inputRaw + inputLength, inputData.begin(),
172 [](_Float16 fp16) { return static_cast<float>(fp16); });
173 } else {
174 FAIL() << "Unknown input operand type for RANDOM_MULTINOMIAL.";
175 }
177 for (int b = 0; b < kBatchSize; ++b) {
178 float probabilitySum = 0;
179 const int batchIndex = kBatchSize * b;
180 for (int i = 0; i < kNumClasses; ++i) {
181 probabilitySum += expf(inputData[batchIndex + i]);
182 }
183 for (int i = 0; i < kNumClasses; ++i) {
184 float probability =
185 static_cast<float>(classCounts[i]) / static_cast<float>(kNumSamples);
186 float probabilityExpected = expf(inputData[batchIndex + i]) / probabilitySum;
187 EXPECT_THAT(probability,
188 ::testing::FloatNear(probabilityExpected,
189 model.expectedMultinomialDistributionTolerance));
190 }
191 }
192 }
194 } // namespace
checkResults(const TestModel & model,const std::vector<TestBuffer> & buffers,const AccuracyCriteria & criteria)196 void checkResults(const TestModel& model, const std::vector<TestBuffer>& buffers,
197 const AccuracyCriteria& criteria) {
198 ASSERT_EQ(model.main.outputIndexes.size(), buffers.size());
199 for (uint32_t i = 0; i < model.main.outputIndexes.size(); i++) {
200 const uint32_t outputIndex = model.main.outputIndexes[i];
201 SCOPED_TRACE(testing::Message()
202 << "When comparing output " << i << " (op" << outputIndex << ")");
203 const auto& operand = model.main.operands[outputIndex];
204 const auto& result = buffers[i];
205 if (operand.isIgnored) continue;
207 switch (operand.type) {
208 case TestOperandType::TENSOR_FLOAT32:
209 expectNear<float>(operand, result, criteria.float32, criteria.allowInvalidFpValues);
210 break;
211 case TestOperandType::TENSOR_FLOAT16:
212 expectNear<_Float16>(operand, result, criteria.float16,
213 criteria.allowInvalidFpValues);
214 break;
215 case TestOperandType::TENSOR_INT32:
216 case TestOperandType::INT32:
217 expectNear<int32_t>(operand, result, criteria.int32);
218 break;
219 case TestOperandType::TENSOR_QUANT8_ASYMM:
220 expectNear<uint8_t>(operand, result, criteria.quant8Asymm);
221 break;
222 case TestOperandType::TENSOR_QUANT8_SYMM:
223 expectNear<int8_t>(operand, result, criteria.quant8Symm);
224 break;
225 case TestOperandType::TENSOR_QUANT16_ASYMM:
226 expectNear<uint16_t>(operand, result, criteria.quant16Asymm);
227 break;
228 case TestOperandType::TENSOR_QUANT16_SYMM:
229 expectNear<int16_t>(operand, result, criteria.quant16Symm);
230 break;
231 case TestOperandType::TENSOR_BOOL8:
232 expectBooleanNearlyEqual(operand, result, criteria.bool8AllowedErrorRatio);
233 break;
234 case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
235 expectNear<int8_t>(operand, result, criteria.quant8AsymmSigned);
236 break;
237 default:
238 FAIL() << "Data type not supported.";
239 }
240 }
241 }
checkResults(const TestModel & model,const std::vector<TestBuffer> & buffers)243 void checkResults(const TestModel& model, const std::vector<TestBuffer>& buffers) {
244 // For RANDOM_MULTINOMIAL test only.
245 if (model.expectedMultinomialDistributionTolerance > 0.0f) {
246 expectMultinomialDistributionWithinTolerance(model, buffers);
247 return;
248 }
250 // Decide the default tolerable range.
251 //
252 // For floating-point models, we use the relaxed precision if either
253 // - relaxed computation flag is set
254 // - the model has at least one TENSOR_FLOAT16 operand
255 //
256 // The bias and MSE criteria are implicitly set to the maximum -- we do not enforce these
257 // criteria in normal generated tests.
258 //
259 // TODO: Adjust the error limit based on testing.
260 //
261 AccuracyCriteria criteria = {
262 // The relative tolerance is 5ULP of FP32.
263 .float32 = {.atol = 1e-5, .rtol = 5.0f * 1.1920928955078125e-7},
264 // Both the absolute and relative tolerance are 5ULP of FP16.
265 .float16 = {.atol = 5.0f * 0.0009765625, .rtol = 5.0f * 0.0009765625},
266 .int32 = {.atol = 1},
267 .quant8Asymm = {.atol = 1},
268 .quant8Symm = {.atol = 1},
269 .quant16Asymm = {.atol = 1},
270 .quant16Symm = {.atol = 1},
271 .bool8AllowedErrorRatio = 0.0f,
272 // Since generated tests are hand-calculated, there should be no invalid FP values.
273 .allowInvalidFpValues = false,
274 };
275 bool hasFloat16Inputs = false;
276 model.forEachSubgraph([&hasFloat16Inputs](const TestSubgraph& subgraph) {
277 if (!hasFloat16Inputs) {
278 hasFloat16Inputs = std::any_of(subgraph.operands.begin(), subgraph.operands.end(),
279 [](const TestOperand& op) {
280 return op.type == TestOperandType::TENSOR_FLOAT16;
281 });
282 }
283 });
284 if (model.isRelaxed || hasFloat16Inputs) {
285 criteria.float32 = criteria.float16;
286 }
287 const double quant8AllowedError = getQuant8AllowedError();
288 criteria.quant8Asymm.atol = quant8AllowedError;
289 criteria.quant8AsymmSigned.atol = quant8AllowedError;
290 criteria.quant8Symm.atol = quant8AllowedError;
292 checkResults(model, buffers, criteria);
293 }
// Returns a copy (made with TestModel::copy()) of |testModel| in which every TENSOR_QUANT8_ASYMM
// operand of the main and referenced subgraphs becomes TENSOR_QUANT8_ASYMM_SIGNED. The zero point
// and each stored byte are shifted down by 128, so the represented real values are unchanged.
TestModel convertQuant8AsymmOperandsToSigned(const TestModel& testModel) {
    const auto shiftToSigned = [](TestSubgraph* subgraph) {
        for (TestOperand& operand : subgraph->operands) {
            if (operand.type != test_helper::TestOperandType::TENSOR_QUANT8_ASYMM) continue;

            operand.type = test_helper::TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED;
            operand.zeroPoint -= 128;
            // Converted in place: each byte is read as uint8_t and written back as int8_t.
            const uint8_t* unsignedBytes = operand.data.get<uint8_t>();
            int8_t* signedBytes = operand.data.getMutable<int8_t>();
            const size_t byteCount = operand.data.size();
            std::transform(unsignedBytes, unsignedBytes + byteCount, signedBytes,
                           [](uint8_t value) {
                               return static_cast<int8_t>(static_cast<int32_t>(value) - 128);
                           });
        }
    };

    TestModel converted(testModel.copy());
    shiftToSigned(&converted.main);
    for (TestSubgraph& referencedSubgraph : converted.referenced) {
        shiftToSigned(&referencedSubgraph);
    }
    return converted;
}
isQuantizedType(TestOperandType type)318 bool isQuantizedType(TestOperandType type) {
319 static const std::set<TestOperandType> kQuantizedTypes = {
320 TestOperandType::TENSOR_QUANT8_ASYMM,
321 TestOperandType::TENSOR_QUANT8_SYMM,
322 TestOperandType::TENSOR_QUANT16_ASYMM,
323 TestOperandType::TENSOR_QUANT16_SYMM,
326 };
327 return kQuantizedTypes.count(type) > 0;
328 }
isFloatType(TestOperandType type)330 bool isFloatType(TestOperandType type) {
331 static const std::set<TestOperandType> kFloatTypes = {
332 TestOperandType::TENSOR_FLOAT32,
333 TestOperandType::TENSOR_FLOAT16,
334 TestOperandType::FLOAT32,
335 TestOperandType::FLOAT16,
336 };
337 return kFloatTypes.count(type) > 0;
338 }
isConstant(TestOperandLifeTime lifetime)340 bool isConstant(TestOperandLifeTime lifetime) {
341 return lifetime == TestOperandLifeTime::CONSTANT_COPY ||
342 lifetime == TestOperandLifeTime::CONSTANT_REFERENCE;
343 }
345 namespace {
// Operation names indexed by the numeric value of TestOperationType (toString() indexes this
// table directly), so the entries must follow the enum's declaration order exactly, with no gaps.
// NOTE(review): order follows the NNAPI OperationType numbering; re-verify against
// TestOperationType whenever the enum changes.
const char* kOperationTypeNames[] = {
        "ADD",
        "AVERAGE_POOL_2D",
        "CONCATENATION",
        "CONV_2D",
        "DEPTHWISE_CONV_2D",
        "DEPTH_TO_SPACE",
        "DEQUANTIZE",
        "EMBEDDING_LOOKUP",
        "FLOOR",
        "FULLY_CONNECTED",
        "HASHTABLE_LOOKUP",
        "L2_NORMALIZATION",
        "L2_POOL",
        "LOCAL_RESPONSE_NORMALIZATION",
        "LOGISTIC",
        "LSH_PROJECTION",
        "LSTM",
        "MAX_POOL_2D",
        "MUL",
        "RELU",
        "RELU1",
        "RELU6",
        "RESHAPE",
        "RESIZE_BILINEAR",
        "RNN",
        "SOFTMAX",
        "SPACE_TO_DEPTH",
        "SVDF",
        "TANH",
        "BATCH_TO_SPACE_ND",
        "DIV",
        "MEAN",
        "PAD",
        "SPACE_TO_BATCH_ND",
        "SQUEEZE",
        "STRIDED_SLICE",
        "SUB",
        "TRANSPOSE",
        "ABS",
        "ARGMAX",
        "ARGMIN",
        "AXIS_ALIGNED_BBOX_TRANSFORM",
        "BIDIRECTIONAL_SEQUENCE_LSTM",
        "BIDIRECTIONAL_SEQUENCE_RNN",
        "BOX_WITH_NMS_LIMIT",
        "CAST",
        "CHANNEL_SHUFFLE",
        "DETECTION_POSTPROCESSING",
        "EQUAL",
        "EXP",
        "EXPAND_DIMS",
        "GATHER",
        "GENERATE_PROPOSALS",
        "GREATER",
        "GREATER_EQUAL",
        "GROUPED_CONV_2D",
        "HEATMAP_MAX_KEYPOINT",
        "INSTANCE_NORMALIZATION",
        "LESS",
        "LESS_EQUAL",
        "LOG",
        "LOGICAL_AND",
        "LOGICAL_NOT",
        "LOGICAL_OR",
        "LOG_SOFTMAX",
        "MAXIMUM",
        "MINIMUM",
        "NEG",
        "NOT_EQUAL",
        "PAD_V2",
        "POW",
        "PRELU",
        "QUANTIZE",
        "QUANTIZED_16BIT_LSTM",
        "RANDOM_MULTINOMIAL",
        "REDUCE_ALL",
        "REDUCE_ANY",
        "REDUCE_MAX",
        "REDUCE_MIN",
        "REDUCE_PROD",
        "REDUCE_SUM",
        "ROI_ALIGN",
        "ROI_POOLING",
        "RSQRT",
        "SELECT",
        "SIN",
        "SLICE",
        "SPLIT",
        "SQRT",
        "TILE",
        "TOPK_V2",
        "TRANSPOSE_CONV_2D",
        "UNIDIRECTIONAL_SEQUENCE_LSTM",
        "UNIDIRECTIONAL_SEQUENCE_RNN",
        "RESIZE_NEAREST_NEIGHBOR",
        "QUANTIZED_LSTM",
        "IF",
        "WHILE",
        "ELU",
        "HARD_SWISH",
        "FILL",
        "RANK",
};
// Operand type names indexed by the numeric value of TestOperandType (toString() indexes this
// table directly). The order must match the enum declaration order, which is the same order
// isScalarType() enumerates.
const char* kOperandTypeNames[] = {
        "FLOAT32",
        "INT32",
        "UINT32",
        "TENSOR_FLOAT32",
        "TENSOR_INT32",
        "TENSOR_QUANT8_ASYMM",
        "BOOL",
        "TENSOR_QUANT16_SYMM",
        "TENSOR_FLOAT16",
        "TENSOR_BOOL8",
        "FLOAT16",
        "TENSOR_QUANT8_SYMM_PER_CHANNEL",
        "TENSOR_QUANT16_ASYMM",
        "TENSOR_QUANT8_SYMM",
        "TENSOR_QUANT8_ASYMM_SIGNED",
};
isScalarType(TestOperandType type)470 bool isScalarType(TestOperandType type) {
471 static const std::vector<bool> kIsScalarOperandType = {
472 true, // TestOperandType::FLOAT32
473 true, // TestOperandType::INT32
474 true, // TestOperandType::UINT32
475 false, // TestOperandType::TENSOR_FLOAT32
476 false, // TestOperandType::TENSOR_INT32
477 false, // TestOperandType::TENSOR_QUANT8_ASYMM
478 true, // TestOperandType::BOOL
479 false, // TestOperandType::TENSOR_QUANT16_SYMM
480 false, // TestOperandType::TENSOR_FLOAT16
481 false, // TestOperandType::TENSOR_BOOL8
482 true, // TestOperandType::FLOAT16
483 false, // TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL
484 false, // TestOperandType::TENSOR_QUANT16_ASYMM
485 false, // TestOperandType::TENSOR_QUANT8_SYMM
486 false, // TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED
487 };
488 return kIsScalarOperandType[static_cast<int>(type)];
489 }
getOperandClassInSpecFile(TestOperandLifeTime lifetime)491 std::string getOperandClassInSpecFile(TestOperandLifeTime lifetime) {
492 switch (lifetime) {
493 case TestOperandLifeTime::SUBGRAPH_INPUT:
494 return "Input";
495 case TestOperandLifeTime::SUBGRAPH_OUTPUT:
496 return "Output";
497 case TestOperandLifeTime::CONSTANT_COPY:
498 case TestOperandLifeTime::CONSTANT_REFERENCE:
499 case TestOperandLifeTime::NO_VALUE:
500 return "Parameter";
501 case TestOperandLifeTime::TEMPORARY_VARIABLE:
502 return "Internal";
503 default:
504 CHECK(false);
505 return "";
506 }
507 }
509 template <typename T>
defaultToStringFunc(const T & value)510 std::string defaultToStringFunc(const T& value) {
511 return std::to_string(value);
512 };
513 template <>
defaultToStringFunc(const _Float16 & value)514 std::string defaultToStringFunc<_Float16>(const _Float16& value) {
515 return defaultToStringFunc(static_cast<float>(value));
516 };
518 // Dump floating point values in hex representation.
519 template <typename T>
520 std::string toHexFloatString(const T& value);
521 template <>
toHexFloatString(const float & value)522 std::string toHexFloatString<float>(const float& value) {
523 std::stringstream ss;
524 ss << "\"" << std::hexfloat << value << "\"";
525 return ss.str();
526 };
527 template <>
toHexFloatString(const _Float16 & value)528 std::string toHexFloatString<_Float16>(const _Float16& value) {
529 return toHexFloatString(static_cast<float>(value));
530 };
// Converts each element of [begin, end) with |func| and concatenates the results, separated by
// |joint|. Returns an empty string for an empty range. Advances the iterator only with ++ and
// compares it only against |end|, so any input iterator works (not just random-access ones).
template <typename Iterator, class ToStringFunc>
std::string join(const std::string& joint, Iterator begin, Iterator end, ToStringFunc func) {
    std::stringstream ss;
    bool isFirst = true;
    for (; begin != end; ++begin) {
        if (!isFirst) ss << joint;
        isFirst = false;
        ss << func(*begin);
    }
    return ss.str();
}

// Convenience overload that joins every element of |range|.
template <typename T, class ToStringFunc>
std::string join(const std::string& joint, const std::vector<T>& range, ToStringFunc func) {
    return join(joint, range.begin(), range.end(), func);
}
546 template <typename T>
dumpTestBufferToSpecFileHelper(const TestBuffer & buffer,bool useHexFloat,std::ostream & os)547 void dumpTestBufferToSpecFileHelper(const TestBuffer& buffer, bool useHexFloat, std::ostream& os) {
548 const T* data = buffer.get<T>();
549 const uint32_t length = buffer.size() / sizeof(T);
550 if constexpr (nnIsFloat<T>) {
551 if (useHexFloat) {
552 os << "from_hex([" << join(", ", data, data + length, toHexFloatString<T>) << "])";
553 return;
554 }
555 }
556 os << "[" << join(", ", data, data + length, defaultToStringFunc<T>) << "]";
557 }
559 } // namespace
toString(TestOperandType type)561 const char* toString(TestOperandType type) {
562 return kOperandTypeNames[static_cast<int>(type)];
563 }
toString(TestOperationType type)565 const char* toString(TestOperationType type) {
566 return kOperationTypeNames[static_cast<int>(type)];
567 }
569 // Dump a test buffer.
dumpTestBuffer(TestOperandType type,const TestBuffer & buffer,bool useHexFloat)570 void SpecDumper::dumpTestBuffer(TestOperandType type, const TestBuffer& buffer, bool useHexFloat) {
571 switch (type) {
572 case TestOperandType::FLOAT32:
573 case TestOperandType::TENSOR_FLOAT32:
574 dumpTestBufferToSpecFileHelper<float>(buffer, useHexFloat, mOs);
575 break;
576 case TestOperandType::INT32:
577 case TestOperandType::TENSOR_INT32:
578 dumpTestBufferToSpecFileHelper<int32_t>(buffer, useHexFloat, mOs);
579 break;
580 case TestOperandType::TENSOR_QUANT8_ASYMM:
581 dumpTestBufferToSpecFileHelper<uint8_t>(buffer, useHexFloat, mOs);
582 break;
583 case TestOperandType::TENSOR_QUANT8_SYMM:
584 case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
585 dumpTestBufferToSpecFileHelper<int8_t>(buffer, useHexFloat, mOs);
586 break;
587 case TestOperandType::TENSOR_QUANT16_ASYMM:
588 dumpTestBufferToSpecFileHelper<uint16_t>(buffer, useHexFloat, mOs);
589 break;
590 case TestOperandType::TENSOR_QUANT16_SYMM:
591 dumpTestBufferToSpecFileHelper<int16_t>(buffer, useHexFloat, mOs);
592 break;
593 case TestOperandType::BOOL:
594 case TestOperandType::TENSOR_BOOL8:
595 dumpTestBufferToSpecFileHelper<bool8>(buffer, useHexFloat, mOs);
596 break;
597 case TestOperandType::FLOAT16:
598 case TestOperandType::TENSOR_FLOAT16:
599 dumpTestBufferToSpecFileHelper<_Float16>(buffer, useHexFloat, mOs);
600 break;
601 default:
602 CHECK(false) << "Unknown type when dumping the buffer";
603 }
604 }
dumpTestOperand(const TestOperand & operand,uint32_t index)606 void SpecDumper::dumpTestOperand(const TestOperand& operand, uint32_t index) {
607 mOs << "op" << index << " = " << getOperandClassInSpecFile(operand.lifetime) << "(\"op" << index
608 << "\", [\"" << toString(operand.type) << "\", ["
609 << join(", ", operand.dimensions, defaultToStringFunc<uint32_t>) << "]";
610 if (operand.scale != 0.0f || operand.zeroPoint != 0) {
611 mOs << ", float.fromhex(" << toHexFloatString(operand.scale) << "), " << operand.zeroPoint;
612 }
613 mOs << "]";
614 if (operand.lifetime == TestOperandLifeTime::CONSTANT_COPY ||
615 operand.lifetime == TestOperandLifeTime::CONSTANT_REFERENCE) {
616 mOs << ", ";
617 dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/true);
618 } else if (operand.lifetime == TestOperandLifeTime::NO_VALUE) {
619 mOs << ", value=None";
620 }
621 mOs << ")";
622 // For quantized data types, append a human-readable scale at the end.
623 if (operand.scale != 0.0f) {
624 mOs << " # scale = " << operand.scale;
625 }
626 // For float buffers, append human-readable values at the end.
627 if (isFloatType(operand.type) &&
628 (operand.lifetime == TestOperandLifeTime::CONSTANT_COPY ||
629 operand.lifetime == TestOperandLifeTime::CONSTANT_REFERENCE)) {
630 mOs << " # ";
631 dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/false);
632 }
633 mOs << "\n";
634 }
dumpTestOperation(const TestOperation & operation)636 void SpecDumper::dumpTestOperation(const TestOperation& operation) {
637 auto toOperandName = [](uint32_t index) { return "op" + std::to_string(index); };
638 mOs << "model = model.Operation(\"" << toString(operation.type) << "\", "
639 << join(", ", operation.inputs, toOperandName) << ").To("
640 << join(", ", operation.outputs, toOperandName) << ")\n";
641 }
dumpTestModel()643 void SpecDumper::dumpTestModel() {
644 CHECK_EQ(kTestModel.referenced.size(), 0u) << "Subgraphs not supported";
645 mOs << "from_hex = lambda l: [float.fromhex(i) for i in l]\n\n";
647 // Dump model operands.
648 mOs << "# Model operands\n";
649 for (uint32_t i = 0; i < kTestModel.main.operands.size(); i++) {
650 dumpTestOperand(kTestModel.main.operands[i], i);
651 }
653 // Dump model operations.
654 mOs << "\n# Model operations\nmodel = Model()\n";
655 for (const auto& operation : kTestModel.main.operations) {
656 dumpTestOperation(operation);
657 }
659 // Dump input/output buffers.
660 mOs << "\n# Example\nExample({\n";
661 for (uint32_t i = 0; i < kTestModel.main.operands.size(); i++) {
662 const auto& operand = kTestModel.main.operands[i];
663 if (operand.lifetime != TestOperandLifeTime::SUBGRAPH_INPUT &&
664 operand.lifetime != TestOperandLifeTime::SUBGRAPH_OUTPUT) {
665 continue;
666 }
667 // For float buffers, dump human-readable values as a comment.
668 if (isFloatType(operand.type)) {
669 mOs << " # op" << i << ": ";
670 dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/false);
671 mOs << "\n";
672 }
673 mOs << " op" << i << ": ";
674 dumpTestBuffer(operand.type, operand.data, /*useHexFloat=*/true);
675 mOs << ",\n";
676 }
677 mOs << "}).DisableLifeTimeVariation()\n";
678 }
dumpResults(const std::string & name,const std::vector<TestBuffer> & results)680 void SpecDumper::dumpResults(const std::string& name, const std::vector<TestBuffer>& results) {
681 CHECK_EQ(results.size(), kTestModel.main.outputIndexes.size());
682 mOs << "\n# Results from " << name << "\n{\n";
683 for (uint32_t i = 0; i < results.size(); i++) {
684 const uint32_t outputIndex = kTestModel.main.outputIndexes[i];
685 const auto& operand = kTestModel.main.operands[outputIndex];
686 // For float buffers, dump human-readable values as a comment.
687 if (isFloatType(operand.type)) {
688 mOs << " # op" << outputIndex << ": ";
689 dumpTestBuffer(operand.type, results[i], /*useHexFloat=*/false);
690 mOs << "\n";
691 }
692 mOs << " op" << outputIndex << ": ";
693 dumpTestBuffer(operand.type, results[i], /*useHexFloat=*/true);
694 mOs << ",\n";
695 }
696 mOs << "}\n";
697 }
699 template <typename T>
convertOperandToFloat32(const TestOperand & op)700 static TestOperand convertOperandToFloat32(const TestOperand& op) {
701 TestOperand converted = op;
702 converted.type =
703 isScalarType(op.type) ? TestOperandType::FLOAT32 : TestOperandType::TENSOR_FLOAT32;
704 converted.scale = 0.0f;
705 converted.zeroPoint = 0;
707 const uint32_t numberOfElements = getNumberOfElements(converted);
708 converted.data = TestBuffer(numberOfElements * sizeof(float));
709 const T* data = op.data.get<T>();
710 float* floatData = converted.data.getMutable<float>();
712 if (op.scale != 0.0f) {
713 std::transform(data, data + numberOfElements, floatData, [&op](T val) {
714 return (static_cast<float>(val) - op.zeroPoint) * op.scale;
715 });
716 } else {
717 std::transform(data, data + numberOfElements, floatData,
718 [](T val) { return static_cast<float>(val); });
719 }
720 return converted;
721 }
convertToFloat32Model(const TestModel & testModel)723 std::optional<TestModel> convertToFloat32Model(const TestModel& testModel) {
724 // Only single-operation graphs are supported.
725 if (testModel.referenced.size() > 0 || testModel.main.operations.size() > 1) {
726 return std::nullopt;
727 }
729 // Check for unsupported operations.
730 CHECK(!testModel.main.operations.empty());
731 const auto& operation = testModel.main.operations[0];
732 // Do not convert type-casting operations.
733 if (operation.type == TestOperationType::DEQUANTIZE ||
734 operation.type == TestOperationType::QUANTIZE ||
735 operation.type == TestOperationType::CAST) {
736 return std::nullopt;
737 }
738 // HASHTABLE_LOOKUP has different behavior in float and quant data types: float
739 // HASHTABLE_LOOKUP will output logical zero when there is a key miss, while quant
740 // HASHTABLE_LOOKUP will output byte zero.
741 if (operation.type == TestOperationType::HASHTABLE_LOOKUP) {
742 return std::nullopt;
743 }
745 auto convert = [&testModel, &operation](const TestOperand& op, uint32_t index) {
746 switch (op.type) {
747 case TestOperandType::TENSOR_FLOAT32:
748 case TestOperandType::FLOAT32:
749 case TestOperandType::TENSOR_BOOL8:
750 case TestOperandType::BOOL:
751 case TestOperandType::UINT32:
752 return op;
753 case TestOperandType::INT32:
754 // The third input of PAD_V2 uses INT32 to specify the padded value.
755 if (operation.type == TestOperationType::PAD_V2 && index == operation.inputs[2]) {
756 // The scale and zero point is inherited from the first input.
757 const uint32_t input0Index = operation.inputs[0];
758 const auto& input0 = testModel.main.operands[input0Index];
759 TestOperand scalarWithScaleAndZeroPoint = op;
760 scalarWithScaleAndZeroPoint.scale = input0.scale;
761 scalarWithScaleAndZeroPoint.zeroPoint = input0.zeroPoint;
762 return convertOperandToFloat32<int32_t>(scalarWithScaleAndZeroPoint);
763 }
764 return op;
765 case TestOperandType::TENSOR_INT32:
766 if (op.scale != 0.0f || op.zeroPoint != 0) {
767 return convertOperandToFloat32<int32_t>(op);
768 }
769 return op;
770 case TestOperandType::TENSOR_FLOAT16:
771 case TestOperandType::FLOAT16:
772 return convertOperandToFloat32<_Float16>(op);
773 case TestOperandType::TENSOR_QUANT8_ASYMM:
774 return convertOperandToFloat32<uint8_t>(op);
775 case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
776 return convertOperandToFloat32<int8_t>(op);
777 case TestOperandType::TENSOR_QUANT16_ASYMM:
778 return convertOperandToFloat32<uint16_t>(op);
779 case TestOperandType::TENSOR_QUANT16_SYMM:
780 return convertOperandToFloat32<int16_t>(op);
781 default:
782 CHECK(false) << "OperandType not supported";
783 return TestOperand{};
784 }
785 };
787 TestModel converted = testModel;
788 for (uint32_t i = 0; i < testModel.main.operands.size(); i++) {
789 converted.main.operands[i] = convert(testModel.main.operands[i], i);
790 }
791 return converted;
792 }
794 template <typename T>
setDataFromFloat32Buffer(const TestBuffer & fpBuffer,TestOperand * op)795 static void setDataFromFloat32Buffer(const TestBuffer& fpBuffer, TestOperand* op) {
796 const uint32_t numberOfElements = getNumberOfElements(*op);
797 const float* floatData = fpBuffer.get<float>();
798 T* data = op->data.getMutable<T>();
800 if (op->scale != 0.0f) {
801 std::transform(floatData, floatData + numberOfElements, data, [op](float val) {
802 int32_t unclamped = std::round(val / op->scale) + op->zeroPoint;
803 int32_t clamped = std::clamp<int32_t>(unclamped, std::numeric_limits<T>::min(),
804 std::numeric_limits<T>::max());
805 return static_cast<T>(clamped);
806 });
807 } else {
808 std::transform(floatData, floatData + numberOfElements, data,
809 [](float val) { return static_cast<T>(val); });
810 }
811 }
setExpectedOutputsFromFloat32Results(const std::vector<TestBuffer> & results,TestModel * model)813 void setExpectedOutputsFromFloat32Results(const std::vector<TestBuffer>& results,
814 TestModel* model) {
815 CHECK_EQ(model->referenced.size(), 0u) << "Subgraphs not supported";
816 CHECK_EQ(model->main.operations.size(), 1u) << "Only single-operation graph is supported";
818 for (uint32_t i = 0; i < results.size(); i++) {
819 uint32_t outputIndex = model->main.outputIndexes[i];
820 auto& op = model->main.operands[outputIndex];
821 switch (op.type) {
822 case TestOperandType::TENSOR_FLOAT32:
823 case TestOperandType::FLOAT32:
824 case TestOperandType::TENSOR_BOOL8:
825 case TestOperandType::BOOL:
826 case TestOperandType::INT32:
827 case TestOperandType::UINT32:
828 op.data = results[i];
829 break;
830 case TestOperandType::TENSOR_INT32:
831 if (op.scale != 0.0f) {
832 setDataFromFloat32Buffer<int32_t>(results[i], &op);
833 } else {
834 op.data = results[i];
835 }
836 break;
837 case TestOperandType::TENSOR_FLOAT16:
838 case TestOperandType::FLOAT16:
839 setDataFromFloat32Buffer<_Float16>(results[i], &op);
840 break;
841 case TestOperandType::TENSOR_QUANT8_ASYMM:
842 setDataFromFloat32Buffer<uint8_t>(results[i], &op);
843 break;
844 case TestOperandType::TENSOR_QUANT8_ASYMM_SIGNED:
845 setDataFromFloat32Buffer<int8_t>(results[i], &op);
846 break;
847 case TestOperandType::TENSOR_QUANT16_ASYMM:
848 setDataFromFloat32Buffer<uint16_t>(results[i], &op);
849 break;
850 case TestOperandType::TENSOR_QUANT16_SYMM:
851 setDataFromFloat32Buffer<int16_t>(results[i], &op);
852 break;
853 default:
854 CHECK(false) << "OperandType not supported";
855 }
856 }
857 }
859 } // namespace test_helper