1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18
19 #include <android-base/logging.h>
20 #include <fcntl.h>
21 #include <ftw.h>
22 #include <gtest/gtest.h>
23 #include <hidlmemory/mapping.h>
24 #include <unistd.h>
25
26 #include <cstdio>
27 #include <cstdlib>
28 #include <random>
29 #include <thread>
30
31 #include "1.3/Callbacks.h"
32 #include "1.3/Utils.h"
33 #include "GeneratedTestHarness.h"
34 #include "MemoryUtils.h"
35 #include "TestHarness.h"
36 #include "Utils.h"
37 #include "VtsHalNeuralnetworks.h"
38
39 // Forward declaration of the mobilenet generated test models in
40 // frameworks/ml/nn/runtime/test/generated/.
41 namespace generated_tests::mobilenet_224_gender_basic_fixed {
42 const test_helper::TestModel& get_test_model();
43 } // namespace generated_tests::mobilenet_224_gender_basic_fixed
44
45 namespace generated_tests::mobilenet_quantized {
46 const test_helper::TestModel& get_test_model();
47 } // namespace generated_tests::mobilenet_quantized
48
49 namespace android::hardware::neuralnetworks::V1_3::vts::functional {
50
51 using namespace test_helper;
52 using implementation::PreparedModelCallback;
53 using V1_1::ExecutionPreference;
54 using V1_2::Constant;
55 using V1_2::OperationType;
56
namespace float32_model {

// Float32 "real world" model used by the caching tests: the mobilenet gender model
// forward-declared at the top of this file.
constexpr auto get_test_model = generated_tests::mobilenet_224_gender_basic_fixed::get_test_model;

}  // namespace float32_model
62
namespace quant8_model {

// Quant8 counterpart of float32_model::get_test_model: the quantized mobilenet model
// forward-declared at the top of this file.
constexpr auto get_test_model = generated_tests::mobilenet_quantized::get_test_model;

}  // namespace quant8_model
68
69 namespace {
70
71 enum class AccessMode { READ_WRITE, READ_ONLY, WRITE_ONLY };
72
73 // Creates cache handles based on provided file groups.
74 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,const std::vector<AccessMode> & mode,hidl_vec<hidl_handle> * handles)75 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
76 const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
77 handles->resize(fileGroups.size());
78 for (uint32_t i = 0; i < fileGroups.size(); i++) {
79 std::vector<int> fds;
80 for (const auto& file : fileGroups[i]) {
81 int fd;
82 if (mode[i] == AccessMode::READ_ONLY) {
83 fd = open(file.c_str(), O_RDONLY);
84 } else if (mode[i] == AccessMode::WRITE_ONLY) {
85 fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
86 } else if (mode[i] == AccessMode::READ_WRITE) {
87 fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
88 } else {
89 FAIL();
90 }
91 ASSERT_GE(fd, 0);
92 fds.push_back(fd);
93 }
94 native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
95 ASSERT_NE(cacheNativeHandle, nullptr);
96 std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
97 (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
98 }
99 }
100
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,AccessMode mode,hidl_vec<hidl_handle> * handles)101 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
102 hidl_vec<hidl_handle>* handles) {
103 createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
104 }
105
106 // Create a chain of broadcast operations. The second operand is always constant tensor [1].
107 // For simplicity, activation scalar is shared. The second operand is not shared
108 // in the model to let driver maintain a non-trivial size of constant data and the corresponding
109 // data locations in cache.
110 //
111 // --------- activation --------
112 // ↓ ↓ ↓ ↓
113 // E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
114 // ↑ ↑ ↑ ↑
115 // [1] [1] [1] [1]
116 //
117 // This function assumes the operation is either ADD or MUL.
// Builds a chain of `len` broadcast operations as shown in the diagram in the preceding comment.
// Operand layout: operand 0 is the shared activation scalar; for operation k, operands
// 2k+1 / 2k+2 are its first / second inputs; the final operand (index 2*len+1) is the model
// output. CppType/operandType select float32 vs quant8 variants.
template <typename CppType, TestOperandType operandType>
TestModel createLargeTestModelImpl(TestOperationType op, uint32_t len) {
    EXPECT_TRUE(op == TestOperationType::ADD || op == TestOperationType::MUL);

    // Model operations and operands.
    std::vector<TestOperation> operations(len);
    std::vector<TestOperand> operands(len * 2 + 2);

    // The activation scalar, value = 0.
    operands[0] = {
            .type = TestOperandType::INT32,
            .dimensions = {},
            .numberOfConsumers = len,
            .scale = 0.0f,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::CONSTANT_COPY,
            .data = TestBuffer::createFromVector<int32_t>({0}),
    };

    // The buffer value of the constant second operand. The logical value is always 1.0f.
    CppType bufferValue;
    // The scale of the first and second operand.
    float scale1, scale2;
    if (operandType == TestOperandType::TENSOR_FLOAT32) {
        // Float models are unscaled.
        bufferValue = 1.0f;
        scale1 = 0.0f;
        scale2 = 0.0f;
    } else if (op == TestOperationType::ADD) {
        bufferValue = 1;
        scale1 = 1.0f;
        scale2 = 1.0f;
    } else {
        // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
        // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
        bufferValue = 2;
        scale1 = 1.0f;
        scale2 = 0.5f;
    }

    for (uint32_t i = 0; i < len; i++) {
        const uint32_t firstInputIndex = i * 2 + 1;
        const uint32_t secondInputIndex = firstInputIndex + 1;
        const uint32_t outputIndex = secondInputIndex + 1;

        // The first operation input: the model input for the first link of the chain, the
        // previous operation's output for every later link.
        operands[firstInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale1,
                .zeroPoint = 0,
                .lifetime = (i == 0 ? TestOperandLifeTime::MODEL_INPUT
                                    : TestOperandLifeTime::TEMPORARY_VARIABLE),
                .data = (i == 0 ? TestBuffer::createFromVector<CppType>({1}) : TestBuffer()),
        };

        // The second operation input, value = 1. Deliberately not shared between operations so
        // the driver keeps a non-trivial amount of constant data (see the comment above).
        operands[secondInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale2,
                .zeroPoint = 0,
                .lifetime = TestOperandLifeTime::CONSTANT_COPY,
                .data = TestBuffer::createFromVector<CppType>({bufferValue}),
        };

        // The operation. All operations share the same activation scalar.
        // The output operand is created as an input in the next iteration of the loop, in the case
        // of all but the last member of the chain; and after the loop as a model output, in the
        // case of the last member of the chain.
        operations[i] = {
                .type = op,
                .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
                .outputs = {outputIndex},
        };
    }

    // For TestOperationType::ADD, output = 1 + 1 * len = len + 1
    // For TestOperationType::MUL, output = 1 * 1 ^ len = 1
    CppType outputResult = static_cast<CppType>(op == TestOperationType::ADD ? len + 1u : 1u);

    // The model output. The expected result is stored in .data for output verification.
    operands.back() = {
            .type = operandType,
            .dimensions = {1},
            .numberOfConsumers = 0,
            .scale = scale1,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
            .data = TestBuffer::createFromVector<CppType>({outputResult}),
    };

    return {
            .main = {.operands = std::move(operands),
                     .operations = std::move(operations),
                     .inputIndexes = {1},
                     .outputIndexes = {len * 2 + 1}},
            .isRelaxed = false,
    };
}
219
220 } // namespace
221
222 // Tag for the compilation caching tests.
// Base fixture for the compilation caching tests.
//
// SetUp() creates a unique temporary cache directory, queries the driver for the number of
// model/data cache files it needs, and creates empty cache files (plus one extra "tmp" file
// used by tests that pass an unexpected number of files or fds). TearDown() removes the cache
// directory recursively, but only when the test passed, so failing runs leave the files behind
// for debugging.
class CompilationCachingTestBase : public testing::Test {
  protected:
    // `type` selects whether the float32 or quant8 test models are used.
    CompilationCachingTestBase(sp<IDevice> device, OperandType type)
        : kDevice(std::move(device)), kOperandType(type) {}

    void SetUp() override {
        testing::Test::SetUp();
        ASSERT_NE(kDevice.get(), nullptr);

        // Create cache directory. The cache directory and a temporary cache file is always created
        // to test the behavior of prepareModelFromCache_1_3, even when caching is not supported.
        char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
        char* cacheDir = mkdtemp(cacheDirTemp);
        ASSERT_NE(cacheDir, nullptr);
        mCacheDir = cacheDir;
        mCacheDir.push_back('/');

        // Ask the driver how many model/data cache files it wants per compilation.
        Return<void> ret = kDevice->getNumberOfCacheFilesNeeded(
                [this](V1_0::ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
                    EXPECT_EQ(V1_0::ErrorStatus::NONE, status);
                    mNumModelCache = numModelCache;
                    mNumDataCache = numDataCache;
                });
        EXPECT_TRUE(ret.isOk());
        // A driver that needs at least one cache file of either kind supports caching.
        mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;

        // Create empty cache files.
        mTmpCache = mCacheDir + "tmp";
        for (uint32_t i = 0; i < mNumModelCache; i++) {
            mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
        }
        for (uint32_t i = 0; i < mNumDataCache; i++) {
            mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
        }
        // Dummy handles, use AccessMode::WRITE_ONLY for createCacheHandles to create files.
        hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
        createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
        createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
        createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);

        if (!mIsCachingSupported) {
            LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
                         "support compilation caching.";
            std::cout << "[ ] Early termination of test because vendor service does not "
                         "support compilation caching."
                      << std::endl;
        }
    }

    void TearDown() override {
        // If the test passes, remove the tmp directory. Otherwise, keep it for debugging purposes.
        if (!testing::Test::HasFailure()) {
            // Recursively remove the cache directory specified by mCacheDir.
            auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
                return remove(entry);
            };
            // FTW_DEPTH visits files before their directory so remove() succeeds bottom-up.
            nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
        }
        testing::Test::TearDown();
    }

    // Model and examples creators. According to kOperandType, the following methods will return
    // either float32 model/examples or the quant8 variant.
    TestModel createTestModel() {
        if (kOperandType == OperandType::TENSOR_FLOAT32) {
            return float32_model::get_test_model();
        } else {
            return quant8_model::get_test_model();
        }
    }

    // Builds a chain of `len` ADD or MUL operations; see createLargeTestModelImpl above.
    TestModel createLargeTestModel(OperationType op, uint32_t len) {
        if (kOperandType == OperandType::TENSOR_FLOAT32) {
            return createLargeTestModelImpl<float, TestOperandType::TENSOR_FLOAT32>(
                    static_cast<TestOperationType>(op), len);
        } else {
            return createLargeTestModelImpl<uint8_t, TestOperandType::TENSOR_QUANT8_ASYMM>(
                    static_cast<TestOperationType>(op), len);
        }
    }

    // See if the service can handle the model. Returns true only when the driver reports every
    // operation of the model's main subgraph as supported.
    bool isModelFullySupported(const Model& model) {
        bool fullySupportsModel = false;
        Return<void> supportedCall = kDevice->getSupportedOperations_1_3(
                model,
                [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
                    ASSERT_EQ(ErrorStatus::NONE, status);
                    ASSERT_EQ(supported.size(), model.main.operations.size());
                    fullySupportsModel = std::all_of(supported.begin(), supported.end(),
                                                     [](bool valid) { return valid; });
                });
        EXPECT_TRUE(supportedCall.isOk());
        return fullySupportsModel;
    }

    // Compiles `model` via prepareModel_1_3 with the given cache handles and the fixture's token,
    // asserting success. On success, optionally returns the prepared model via `preparedModel`.
    void saveModelToCache(const Model& model, const hidl_vec<hidl_handle>& modelCache,
                          const hidl_vec<hidl_handle>& dataCache,
                          sp<IPreparedModel>* preparedModel = nullptr) {
        if (preparedModel != nullptr) *preparedModel = nullptr;

        // Launch prepare model.
        sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
        hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
        Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModel_1_3(
                model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, {}, modelCache,
                dataCache, cacheToken, preparedModelCallback);
        ASSERT_TRUE(prepareLaunchStatus.isOk());
        ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);

        // Retrieve prepared model.
        preparedModelCallback->wait();
        ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
        if (preparedModel != nullptr) {
            *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
                                     .withDefault(nullptr);
        }
    }

    // Returns true (after logging) when `status` is GENERAL_FAILURE, i.e. the driver cannot save
    // a prepared model that it does not support; callers use this to end the test early.
    bool checkEarlyTermination(ErrorStatus status) {
        if (status == ErrorStatus::GENERAL_FAILURE) {
            LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                         "save the prepared model that it does not support.";
            std::cout << "[ ] Early termination of test because vendor service cannot "
                         "save the prepared model that it does not support."
                      << std::endl;
            return true;
        }
        return false;
    }

    // Returns true (after logging) when the driver does not fully support `model`; callers use
    // this to end the test early.
    bool checkEarlyTermination(const Model& model) {
        if (!isModelFullySupported(model)) {
            LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                         "prepare model that it does not support.";
            std::cout << "[ ] Early termination of test because vendor service cannot "
                         "prepare model that it does not support."
                      << std::endl;
            return true;
        }
        return false;
    }

    // Calls prepareModelFromCache_1_3 with the given cache handles and the fixture's token.
    // `*status` receives the driver-reported status; on any failure `*preparedModel` is nullptr.
    void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
                               const hidl_vec<hidl_handle>& dataCache,
                               sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
        // Launch prepare model from cache.
        sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
        hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
        Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModelFromCache_1_3(
                {}, modelCache, dataCache, cacheToken, preparedModelCallback);
        ASSERT_TRUE(prepareLaunchStatus.isOk());
        if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
            *preparedModel = nullptr;
            *status = static_cast<ErrorStatus>(prepareLaunchStatus);
            return;
        }

        // Retrieve prepared model.
        preparedModelCallback->wait();
        *status = preparedModelCallback->getStatus();
        *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
                                 .withDefault(nullptr);
    }

    // Absolute path to the temporary cache directory.
    std::string mCacheDir;

    // Groups of file paths for model and data cache in the tmp cache directory, initialized with
    // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
    // and the inner vector is for fds held by each handle.
    std::vector<std::vector<std::string>> mModelCache;
    std::vector<std::vector<std::string>> mDataCache;

    // A separate temporary file path in the tmp cache directory.
    std::string mTmpCache;

    // Cache token identifying this compilation; zero-initialized (all-zero is a valid token).
    uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
    // Counts reported by getNumberOfCacheFilesNeeded in SetUp().
    uint32_t mNumModelCache;
    uint32_t mNumDataCache;
    // Non-zero iff the driver needs at least one cache file (boolean value stored as uint32_t).
    uint32_t mIsCachingSupported;

    const sp<IDevice> kDevice;
    // The primary data type of the testModel.
    const OperandType kOperandType;
};
409
410 using CompilationCachingTestParam = std::tuple<NamedDevice, OperandType>;
411
// A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
// pass running with float32 models and the second pass running with quant8 models.
class CompilationCachingTest : public CompilationCachingTestBase,
                               public testing::WithParamInterface<CompilationCachingTestParam> {
  protected:
    // Resolves the NamedDevice from the test parameter to an IDevice and forwards the
    // operand type to the base fixture.
    CompilationCachingTest()
        : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
                                     std::get<OperandType>(GetParam())) {}
};
421
// Basic round trip: compile a model while saving to cache, then prepare the same model from the
// cached files and verify that the cached compilation still executes correctly.
TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    sp<IPreparedModel> preparedModel = nullptr;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Retrieve preparedModel from cache.
    {
        preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (!mIsCachingSupported) {
            // A driver without caching support must report GENERAL_FAILURE here.
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else {
            ASSERT_EQ(status, ErrorStatus::NONE);
            ASSERT_NE(preparedModel, nullptr);
        }
    }

    // Execute and verify results.
    EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
}
461
// Same round trip as CacheSavingAndRetrieval, but the cache files already contain data and the
// fds are handed to the driver with a non-zero read/write offset; the driver must handle both.
TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    sp<IPreparedModel> preparedModel = nullptr;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        uint8_t dummyBytes[] = {0, 0};
        // Write a dummy integer to the cache.
        // The driver should be able to handle non-empty cache and non-zero fd offset.
        for (uint32_t i = 0; i < modelCache.size(); i++) {
            ASSERT_EQ(write(modelCache[i].getNativeHandle()->data[0], &dummyBytes,
                            sizeof(dummyBytes)),
                      sizeof(dummyBytes));
        }
        for (uint32_t i = 0; i < dataCache.size(); i++) {
            ASSERT_EQ(
                    write(dataCache[i].getNativeHandle()->data[0], &dummyBytes, sizeof(dummyBytes)),
                    sizeof(dummyBytes));
        }
        saveModelToCache(model, modelCache, dataCache);
    }

    // Retrieve preparedModel from cache.
    {
        preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        uint8_t dummyByte = 0;
        // Advance the offset of each handle by one byte.
        // The driver should be able to handle non-zero fd offset.
        for (uint32_t i = 0; i < modelCache.size(); i++) {
            ASSERT_GE(read(modelCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
        }
        for (uint32_t i = 0; i < dataCache.size(); i++) {
            ASSERT_GE(read(dataCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
        }
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (!mIsCachingSupported) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else {
            ASSERT_EQ(status, ErrorStatus::NONE);
            ASSERT_NE(preparedModel, nullptr);
        }
    }

    // Execute and verify results.
    EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
}
523
// Passing a wrong number of model/data cache files to prepareModel_1_3 must not break the
// (uncached) compilation itself, and a subsequent prepareModelFromCache_1_3 on the mismatched
// files must fail with INVALID_ARGUMENT or GENERAL_FAILURE.
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Test with number of model cache files greater than mNumModelCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for model cache.
        // (Push/pop around handle creation so the fixture's lists are restored for later tests.)
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for data cache.
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
624
// Saves a valid cache first, then verifies that prepareModelFromCache_1_3 rejects calls made
// with too many or too few model/data cache files (GENERAL_FAILURE or INVALID_ARGUMENT, and no
// prepared model returned).
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Test with number of model cache files greater than mNumModelCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Temporarily add an extra file; restore the list right after the handles are created.
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
705
// For each cache handle in turn, passes an invalid number of fds (2 instead of 1, then 0) to
// prepareModel_1_3. The compilation itself must still succeed (ignoring the bad cache handles),
// and the subsequent prepareModelFromCache_1_3 on the same handles must fail.
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Go through each handle in model cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        // (Push/pop around handle creation so the fixture's lists are restored each iteration.)
        mModelCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in model cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        auto tmp = mModelCache[i].back();
        mModelCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        mDataCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        auto tmp = mDataCache[i].back();
        mDataCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
806
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
    // Build the test model and bail out early if the driver cannot compile it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // First, write a valid compilation into the cache files.
    {
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        saveModelToCache(model, modelHandles, dataHandles);
    }

    // For each model cache entry in turn, hand the driver a handle carrying two fds
    // instead of the expected single fd.
    for (uint32_t index = 0; index < mNumModelCache; index++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        mModelCache[index].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mModelCache[index].pop_back();
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        // The driver must reject the request with GENERAL_FAILURE or INVALID_ARGUMENT.
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // For each model cache entry in turn, hand the driver a handle carrying no fds.
    for (uint32_t index = 0; index < mNumModelCache; index++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        auto removedName = mModelCache[index].back();
        mModelCache[index].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mModelCache[index].push_back(removedName);
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Same two cases, applied to each data cache entry: two fds per handle...
    for (uint32_t index = 0; index < mNumDataCache; index++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        mDataCache[index].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mDataCache[index].pop_back();
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // ...and no fds per handle.
    for (uint32_t index = 0; index < mNumDataCache; index++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        auto removedName = mDataCache[index].back();
        mDataCache[index].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mDataCache[index].push_back(removedName);
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
887
TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
    // Build the test model and bail out early if the driver cannot compile it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // For each model cache entry in turn, open that one file read-only while saving.
    for (uint32_t index = 0; index < mNumModelCache; index++) {
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        modelCacheMode[index] = AccessMode::READ_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelHandles);
        createCacheHandles(mDataCache, dataCacheMode, &dataHandles);
        modelCacheMode[index] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        // Saving must still yield a working prepared model...
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // ...but preparing from the (unwritable) cache must fail with
        // INVALID_ARGUMENT or GENERAL_FAILURE.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Same check, with one data cache entry opened read-only.
    for (uint32_t index = 0; index < mNumDataCache; index++) {
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        dataCacheMode[index] = AccessMode::READ_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelHandles);
        createCacheHandles(mDataCache, dataCacheMode, &dataHandles);
        dataCacheMode[index] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
940
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
    // Build the test model and bail out early if the driver cannot compile it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // First, write a valid compilation into the cache files.
    {
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        saveModelToCache(model, modelHandles, dataHandles);
    }

    // For each model cache entry in turn, open that one file write-only;
    // preparing from cache must fail with GENERAL_FAILURE.
    for (uint32_t index = 0; index < mNumModelCache; index++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        modelCacheMode[index] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelHandles);
        createCacheHandles(mDataCache, dataCacheMode, &dataHandles);
        modelCacheMode[index] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Same check, with one data cache entry opened write-only.
    for (uint32_t index = 0; index < mNumDataCache; index++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        dataCacheMode[index] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelHandles);
        createCacheHandles(mDataCache, dataCacheMode, &dataHandles);
        dataCacheMode[index] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
985
986 // Copy file contents between file groups.
987 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
988 // The outer vector sizes must match and the inner vectors must have size = 1.
copyCacheFiles(const std::vector<std::vector<std::string>> & from,const std::vector<std::vector<std::string>> & to)989 static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
990 const std::vector<std::vector<std::string>>& to) {
991 constexpr size_t kBufferSize = 1000000;
992 uint8_t buffer[kBufferSize];
993
994 ASSERT_EQ(from.size(), to.size());
995 for (uint32_t i = 0; i < from.size(); i++) {
996 ASSERT_EQ(from[i].size(), 1u);
997 ASSERT_EQ(to[i].size(), 1u);
998 int fromFd = open(from[i][0].c_str(), O_RDONLY);
999 int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
1000 ASSERT_GE(fromFd, 0);
1001 ASSERT_GE(toFd, 0);
1002
1003 ssize_t readBytes;
1004 while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
1005 ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
1006 }
1007 ASSERT_GE(readBytes, 0);
1008
1009 close(fromFd);
1010 close(toFd);
1011 }
1012 }
1013
// Number of operations in the large test model.
constexpr uint32_t kLargeModelSize = 100;
// Number of repetitions for the probabilistic TOCTOU tests below.
constexpr uint32_t kNumIterationsTOCTOU = 100;
1017
// TOCTOU (time-of-check/time-of-use) test: while the driver saves a compilation to
// cache, a second thread concurrently overwrites the model cache files with the
// contents of a different model's cache. The driver must not crash, and if it later
// prepares a model from that cache, the prepared model must compute correct results.
TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache, using file names suffixed with "_mul"
    // so they do not collide with modelAdd's cache files.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while saving to cache.
            // This races copyCacheFiles(modelCacheMul -> mModelCache) against the driver's
            // own writes to the same files.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            saveModelToCache(modelAdd, modelCache, dataCache);
            thread.join();
        }

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
                                      /*testKind=*/TestKind::GENERAL);
            }
        }
    }
}
1079
// TOCTOU (time-of-check/time-of-use) test, mirror of SaveToCache_TOCTOU: here the
// concurrent overwrite of the model cache files races against prepareModelFromCache
// rather than against saveModelToCache. The driver must not crash, and any model it
// successfully prepares must compute correct results.
TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache, using file names suffixed with "_mul"
    // so they do not collide with modelAdd's cache files.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            saveModelToCache(modelAdd, modelCache, dataCache);
        }

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while preparing from cache.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
            thread.join();

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
                                      /*testKind=*/TestKind::GENERAL);
            }
        }
    }
}
1141
TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) return;

    // Build two distinct large models and verify the driver supports both.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Compile modelMul into a separate set of cache files (names suffixed "_mul").
    auto modelCacheMul = mModelCache;
    for (auto& entry : modelCacheMul) {
        entry[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        saveModelToCache(modelMul, modelHandles, dataHandles);
    }

    // Compile modelAdd under a different cache token.
    mToken[0]++;
    {
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        saveModelToCache(modelAdd, modelHandles, dataHandles);
    }

    // Overwrite modelAdd's model cache files with modelMul's content.
    copyCacheFiles(modelCacheMul, mModelCache);

    // Preparing from the swapped-out cache must fail with GENERAL_FAILURE.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles, dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
1191
// Parameter axes for instantiating the caching tests: every registered device,
// crossed with the two operand types used to generate the test models.
static const auto kNamedDeviceChoices = testing::ValuesIn(getNamedDevices());
static const auto kOperandTypeChoices =
        testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);
1195
printCompilationCachingTest(const testing::TestParamInfo<CompilationCachingTestParam> & info)1196 std::string printCompilationCachingTest(
1197 const testing::TestParamInfo<CompilationCachingTestParam>& info) {
1198 const auto& [namedDevice, operandType] = info.param;
1199 const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1200 return gtestCompliantName(getName(namedDevice) + "_" + type);
1201 }
1202
// Instantiate CompilationCachingTest once per (device, operand type) combination.
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest,
                        testing::Combine(kNamedDeviceChoices, kOperandTypeChoices),
                        printCompilationCachingTest);
1206
1207 using CompilationCachingSecurityTestParam = std::tuple<NamedDevice, OperandType, uint32_t>;
1208
1209 class CompilationCachingSecurityTest
1210 : public CompilationCachingTestBase,
1211 public testing::WithParamInterface<CompilationCachingSecurityTestParam> {
1212 protected:
CompilationCachingSecurityTest()1213 CompilationCachingSecurityTest()
1214 : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
1215 std::get<OperandType>(GetParam())) {}
1216
SetUp()1217 void SetUp() {
1218 CompilationCachingTestBase::SetUp();
1219 generator.seed(kSeed);
1220 }
1221
1222 // Get a random integer within a closed range [lower, upper].
1223 template <typename T>
getRandomInt(T lower,T upper)1224 T getRandomInt(T lower, T upper) {
1225 std::uniform_int_distribution<T> dis(lower, upper);
1226 return dis(generator);
1227 }
1228
1229 // Randomly flip one single bit of the cache entry.
flipOneBitOfCache(const std::string & filename,bool * skip)1230 void flipOneBitOfCache(const std::string& filename, bool* skip) {
1231 FILE* pFile = fopen(filename.c_str(), "r+");
1232 ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
1233 long int fileSize = ftell(pFile);
1234 if (fileSize == 0) {
1235 fclose(pFile);
1236 *skip = true;
1237 return;
1238 }
1239 ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
1240 int readByte = fgetc(pFile);
1241 ASSERT_NE(readByte, EOF);
1242 ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
1243 ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
1244 fclose(pFile);
1245 *skip = false;
1246 }
1247
1248 // Randomly append bytes to the cache entry.
appendBytesToCache(const std::string & filename,bool * skip)1249 void appendBytesToCache(const std::string& filename, bool* skip) {
1250 FILE* pFile = fopen(filename.c_str(), "a");
1251 uint32_t appendLength = getRandomInt(1, 256);
1252 for (uint32_t i = 0; i < appendLength; i++) {
1253 ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
1254 }
1255 fclose(pFile);
1256 *skip = false;
1257 }
1258
1259 enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
1260
1261 // Test if the driver behaves as expected when given corrupted cache or token.
1262 // The modifier will be invoked after save to cache but before prepare from cache.
1263 // The modifier accepts one pointer argument "skip" as the returning value, indicating
1264 // whether the test should be skipped or not.
testCorruptedCache(ExpectedResult expected,std::function<void (bool *)> modifier)1265 void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
1266 const TestModel& testModel = createTestModel();
1267 const Model model = createModel(testModel);
1268 if (checkEarlyTermination(model)) return;
1269
1270 // Save the compilation to cache.
1271 {
1272 hidl_vec<hidl_handle> modelCache, dataCache;
1273 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1274 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1275 saveModelToCache(model, modelCache, dataCache);
1276 }
1277
1278 bool skip = false;
1279 modifier(&skip);
1280 if (skip) return;
1281
1282 // Retrieve preparedModel from cache.
1283 {
1284 sp<IPreparedModel> preparedModel = nullptr;
1285 ErrorStatus status;
1286 hidl_vec<hidl_handle> modelCache, dataCache;
1287 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1288 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1289 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1290
1291 switch (expected) {
1292 case ExpectedResult::GENERAL_FAILURE:
1293 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1294 ASSERT_EQ(preparedModel, nullptr);
1295 break;
1296 case ExpectedResult::NOT_CRASH:
1297 ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
1298 break;
1299 default:
1300 FAIL();
1301 }
1302 }
1303 }
1304
1305 const uint32_t kSeed = std::get<uint32_t>(GetParam());
1306 std::mt19937 generator;
1307 };
1308
// Flipping a single bit in any model cache file must make prepareModelFromCache fail.
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumModelCache; index++) {
        const auto corrupt = [this, index](bool* skip) {
            flipOneBitOfCache(mModelCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corrupt);
    }
}
1316
// Appending random bytes to any model cache file must make prepareModelFromCache fail.
TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumModelCache; index++) {
        const auto corrupt = [this, index](bool* skip) {
            appendBytesToCache(mModelCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corrupt);
    }
}
1324
// Flipping a bit in a data cache file may or may not be detected; require only no crash.
TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumDataCache; index++) {
        const auto corrupt = [this, index](bool* skip) {
            flipOneBitOfCache(mDataCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corrupt);
    }
}
1332
// Appending bytes to a data cache file may or may not be detected; require only no crash.
TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumDataCache; index++) {
        const auto corrupt = [this, index](bool* skip) {
            appendBytesToCache(mDataCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corrupt);
    }
}
1340
// A cache saved under one token must not be retrievable under a token that differs
// by even a single bit.
TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) return;
    const auto corruptToken = [this](bool* skip) {
        // Randomly flip one single bit in mToken.
        const uint32_t tokenSize = static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN);
        const uint32_t byteIndex = getRandomInt(0u, tokenSize - 1);
        mToken[byteIndex] ^= (1U << getRandomInt(0, 7));
        *skip = false;
    };
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptToken);
}
1351
printCompilationCachingSecurityTest(const testing::TestParamInfo<CompilationCachingSecurityTestParam> & info)1352 std::string printCompilationCachingSecurityTest(
1353 const testing::TestParamInfo<CompilationCachingSecurityTestParam>& info) {
1354 const auto& [namedDevice, operandType, seed] = info.param;
1355 const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1356 return gtestCompliantName(getName(namedDevice) + "_" + type + "_" + std::to_string(seed));
1357 }
1358
// Instantiate the security tests once per (device, operand type, seed) combination,
// with ten different PRNG seeds so corruption hits varied file offsets/bits.
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingSecurityTest,
                        testing::Combine(kNamedDeviceChoices, kOperandTypeChoices,
                                         testing::Range(0U, 10U)),
                        printCompilationCachingSecurityTest);
1363
1364 } // namespace android::hardware::neuralnetworks::V1_3::vts::functional
1365