1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include <android-base/logging.h>
20 #include <fcntl.h>
21 #include <ftw.h>
22 #include <gtest/gtest.h>
23 #include <hidlmemory/mapping.h>
24 #include <unistd.h>
25 
26 #include <cstdio>
27 #include <cstdlib>
28 #include <random>
29 #include <thread>
30 
31 #include "1.3/Callbacks.h"
32 #include "1.3/Utils.h"
33 #include "GeneratedTestHarness.h"
34 #include "MemoryUtils.h"
35 #include "TestHarness.h"
36 #include "Utils.h"
37 #include "VtsHalNeuralnetworks.h"
38 
39 // Forward declaration of the mobilenet generated test models in
40 // frameworks/ml/nn/runtime/test/generated/.
41 namespace generated_tests::mobilenet_224_gender_basic_fixed {
42 const test_helper::TestModel& get_test_model();
43 }  // namespace generated_tests::mobilenet_224_gender_basic_fixed
44 
45 namespace generated_tests::mobilenet_quantized {
46 const test_helper::TestModel& get_test_model();
47 }  // namespace generated_tests::mobilenet_quantized
48 
49 namespace android::hardware::neuralnetworks::V1_3::vts::functional {
50 
51 using namespace test_helper;
52 using implementation::PreparedModelCallback;
53 using V1_1::ExecutionPreference;
54 using V1_2::Constant;
55 using V1_2::OperationType;
56 
57 namespace float32_model {
58 
59 constexpr auto get_test_model = generated_tests::mobilenet_224_gender_basic_fixed::get_test_model;
60 
61 }  // namespace float32_model
62 
63 namespace quant8_model {
64 
65 constexpr auto get_test_model = generated_tests::mobilenet_quantized::get_test_model;
66 
67 }  // namespace quant8_model
68 
69 namespace {
70 
// Selects the open() flags that createCacheHandles() uses for every file of a cache handle.
enum class AccessMode {
    READ_WRITE,  // O_RDWR | O_CREAT: the file is created if it does not exist.
    READ_ONLY,   // O_RDONLY: the file must already exist.
    WRITE_ONLY,  // O_WRONLY | O_CREAT: the file is created if it does not exist.
};
72 
73 // Creates cache handles based on provided file groups.
74 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,const std::vector<AccessMode> & mode,hidl_vec<hidl_handle> * handles)75 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
76                         const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
77     handles->resize(fileGroups.size());
78     for (uint32_t i = 0; i < fileGroups.size(); i++) {
79         std::vector<int> fds;
80         for (const auto& file : fileGroups[i]) {
81             int fd;
82             if (mode[i] == AccessMode::READ_ONLY) {
83                 fd = open(file.c_str(), O_RDONLY);
84             } else if (mode[i] == AccessMode::WRITE_ONLY) {
85                 fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
86             } else if (mode[i] == AccessMode::READ_WRITE) {
87                 fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
88             } else {
89                 FAIL();
90             }
91             ASSERT_GE(fd, 0);
92             fds.push_back(fd);
93         }
94         native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
95         ASSERT_NE(cacheNativeHandle, nullptr);
96         std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
97         (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
98     }
99 }
100 
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,AccessMode mode,hidl_vec<hidl_handle> * handles)101 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
102                         hidl_vec<hidl_handle>* handles) {
103     createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
104 }
105 
106 // Create a chain of broadcast operations. The second operand is always constant tensor [1].
107 // For simplicity, activation scalar is shared. The second operand is not shared
108 // in the model to let driver maintain a non-trivial size of constant data and the corresponding
109 // data locations in cache.
110 //
111 //                --------- activation --------
112 //                ↓      ↓      ↓             ↓
113 // E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
114 //                ↑      ↑      ↑             ↑
115 //               [1]    [1]    [1]           [1]
116 //
117 // This function assumes the operation is either ADD or MUL.
118 template <typename CppType, TestOperandType operandType>
createLargeTestModelImpl(TestOperationType op,uint32_t len)119 TestModel createLargeTestModelImpl(TestOperationType op, uint32_t len) {
120     EXPECT_TRUE(op == TestOperationType::ADD || op == TestOperationType::MUL);
121 
122     // Model operations and operands.
123     std::vector<TestOperation> operations(len);
124     std::vector<TestOperand> operands(len * 2 + 2);
125 
126     // The activation scalar, value = 0.
127     operands[0] = {
128             .type = TestOperandType::INT32,
129             .dimensions = {},
130             .numberOfConsumers = len,
131             .scale = 0.0f,
132             .zeroPoint = 0,
133             .lifetime = TestOperandLifeTime::CONSTANT_COPY,
134             .data = TestBuffer::createFromVector<int32_t>({0}),
135     };
136 
137     // The buffer value of the constant second operand. The logical value is always 1.0f.
138     CppType bufferValue;
139     // The scale of the first and second operand.
140     float scale1, scale2;
141     if (operandType == TestOperandType::TENSOR_FLOAT32) {
142         bufferValue = 1.0f;
143         scale1 = 0.0f;
144         scale2 = 0.0f;
145     } else if (op == TestOperationType::ADD) {
146         bufferValue = 1;
147         scale1 = 1.0f;
148         scale2 = 1.0f;
149     } else {
150         // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
151         // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
152         bufferValue = 2;
153         scale1 = 1.0f;
154         scale2 = 0.5f;
155     }
156 
157     for (uint32_t i = 0; i < len; i++) {
158         const uint32_t firstInputIndex = i * 2 + 1;
159         const uint32_t secondInputIndex = firstInputIndex + 1;
160         const uint32_t outputIndex = secondInputIndex + 1;
161 
162         // The first operation input.
163         operands[firstInputIndex] = {
164                 .type = operandType,
165                 .dimensions = {1},
166                 .numberOfConsumers = 1,
167                 .scale = scale1,
168                 .zeroPoint = 0,
169                 .lifetime = (i == 0 ? TestOperandLifeTime::MODEL_INPUT
170                                     : TestOperandLifeTime::TEMPORARY_VARIABLE),
171                 .data = (i == 0 ? TestBuffer::createFromVector<CppType>({1}) : TestBuffer()),
172         };
173 
174         // The second operation input, value = 1.
175         operands[secondInputIndex] = {
176                 .type = operandType,
177                 .dimensions = {1},
178                 .numberOfConsumers = 1,
179                 .scale = scale2,
180                 .zeroPoint = 0,
181                 .lifetime = TestOperandLifeTime::CONSTANT_COPY,
182                 .data = TestBuffer::createFromVector<CppType>({bufferValue}),
183         };
184 
185         // The operation. All operations share the same activation scalar.
186         // The output operand is created as an input in the next iteration of the loop, in the case
187         // of all but the last member of the chain; and after the loop as a model output, in the
188         // case of the last member of the chain.
189         operations[i] = {
190                 .type = op,
191                 .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
192                 .outputs = {outputIndex},
193         };
194     }
195 
196     // For TestOperationType::ADD, output = 1 + 1 * len = len + 1
197     // For TestOperationType::MUL, output = 1 * 1 ^ len = 1
198     CppType outputResult = static_cast<CppType>(op == TestOperationType::ADD ? len + 1u : 1u);
199 
200     // The model output.
201     operands.back() = {
202             .type = operandType,
203             .dimensions = {1},
204             .numberOfConsumers = 0,
205             .scale = scale1,
206             .zeroPoint = 0,
207             .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
208             .data = TestBuffer::createFromVector<CppType>({outputResult}),
209     };
210 
211     return {
212             .main = {.operands = std::move(operands),
213                      .operations = std::move(operations),
214                      .inputIndexes = {1},
215                      .outputIndexes = {len * 2 + 1}},
216             .isRelaxed = false,
217     };
218 }
219 
220 }  // namespace
221 
222 // Tag for the compilation caching tests.
223 class CompilationCachingTestBase : public testing::Test {
224   protected:
CompilationCachingTestBase(sp<IDevice> device,OperandType type)225     CompilationCachingTestBase(sp<IDevice> device, OperandType type)
226         : kDevice(std::move(device)), kOperandType(type) {}
227 
SetUp()228     void SetUp() override {
229         testing::Test::SetUp();
230         ASSERT_NE(kDevice.get(), nullptr);
231 
232         // Create cache directory. The cache directory and a temporary cache file is always created
233         // to test the behavior of prepareModelFromCache_1_3, even when caching is not supported.
234         char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
235         char* cacheDir = mkdtemp(cacheDirTemp);
236         ASSERT_NE(cacheDir, nullptr);
237         mCacheDir = cacheDir;
238         mCacheDir.push_back('/');
239 
240         Return<void> ret = kDevice->getNumberOfCacheFilesNeeded(
241                 [this](V1_0::ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
242                     EXPECT_EQ(V1_0::ErrorStatus::NONE, status);
243                     mNumModelCache = numModelCache;
244                     mNumDataCache = numDataCache;
245                 });
246         EXPECT_TRUE(ret.isOk());
247         mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;
248 
249         // Create empty cache files.
250         mTmpCache = mCacheDir + "tmp";
251         for (uint32_t i = 0; i < mNumModelCache; i++) {
252             mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
253         }
254         for (uint32_t i = 0; i < mNumDataCache; i++) {
255             mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
256         }
257         // Dummy handles, use AccessMode::WRITE_ONLY for createCacheHandles to create files.
258         hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
259         createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
260         createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
261         createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);
262 
263         if (!mIsCachingSupported) {
264             LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
265                          "support compilation caching.";
266             std::cout << "[          ]   Early termination of test because vendor service does not "
267                          "support compilation caching."
268                       << std::endl;
269         }
270     }
271 
TearDown()272     void TearDown() override {
273         // If the test passes, remove the tmp directory.  Otherwise, keep it for debugging purposes.
274         if (!testing::Test::HasFailure()) {
275             // Recursively remove the cache directory specified by mCacheDir.
276             auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
277                 return remove(entry);
278             };
279             nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
280         }
281         testing::Test::TearDown();
282     }
283 
284     // Model and examples creators. According to kOperandType, the following methods will return
285     // either float32 model/examples or the quant8 variant.
createTestModel()286     TestModel createTestModel() {
287         if (kOperandType == OperandType::TENSOR_FLOAT32) {
288             return float32_model::get_test_model();
289         } else {
290             return quant8_model::get_test_model();
291         }
292     }
293 
createLargeTestModel(OperationType op,uint32_t len)294     TestModel createLargeTestModel(OperationType op, uint32_t len) {
295         if (kOperandType == OperandType::TENSOR_FLOAT32) {
296             return createLargeTestModelImpl<float, TestOperandType::TENSOR_FLOAT32>(
297                     static_cast<TestOperationType>(op), len);
298         } else {
299             return createLargeTestModelImpl<uint8_t, TestOperandType::TENSOR_QUANT8_ASYMM>(
300                     static_cast<TestOperationType>(op), len);
301         }
302     }
303 
304     // See if the service can handle the model.
isModelFullySupported(const Model & model)305     bool isModelFullySupported(const Model& model) {
306         bool fullySupportsModel = false;
307         Return<void> supportedCall = kDevice->getSupportedOperations_1_3(
308                 model,
309                 [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
310                     ASSERT_EQ(ErrorStatus::NONE, status);
311                     ASSERT_EQ(supported.size(), model.main.operations.size());
312                     fullySupportsModel = std::all_of(supported.begin(), supported.end(),
313                                                      [](bool valid) { return valid; });
314                 });
315         EXPECT_TRUE(supportedCall.isOk());
316         return fullySupportsModel;
317     }
318 
saveModelToCache(const Model & model,const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,sp<IPreparedModel> * preparedModel=nullptr)319     void saveModelToCache(const Model& model, const hidl_vec<hidl_handle>& modelCache,
320                           const hidl_vec<hidl_handle>& dataCache,
321                           sp<IPreparedModel>* preparedModel = nullptr) {
322         if (preparedModel != nullptr) *preparedModel = nullptr;
323 
324         // Launch prepare model.
325         sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
326         hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
327         Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModel_1_3(
328                 model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, {}, modelCache,
329                 dataCache, cacheToken, preparedModelCallback);
330         ASSERT_TRUE(prepareLaunchStatus.isOk());
331         ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);
332 
333         // Retrieve prepared model.
334         preparedModelCallback->wait();
335         ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
336         if (preparedModel != nullptr) {
337             *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
338                                      .withDefault(nullptr);
339         }
340     }
341 
checkEarlyTermination(ErrorStatus status)342     bool checkEarlyTermination(ErrorStatus status) {
343         if (status == ErrorStatus::GENERAL_FAILURE) {
344             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
345                          "save the prepared model that it does not support.";
346             std::cout << "[          ]   Early termination of test because vendor service cannot "
347                          "save the prepared model that it does not support."
348                       << std::endl;
349             return true;
350         }
351         return false;
352     }
353 
checkEarlyTermination(const Model & model)354     bool checkEarlyTermination(const Model& model) {
355         if (!isModelFullySupported(model)) {
356             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
357                          "prepare model that it does not support.";
358             std::cout << "[          ]   Early termination of test because vendor service cannot "
359                          "prepare model that it does not support."
360                       << std::endl;
361             return true;
362         }
363         return false;
364     }
365 
prepareModelFromCache(const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,sp<IPreparedModel> * preparedModel,ErrorStatus * status)366     void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
367                                const hidl_vec<hidl_handle>& dataCache,
368                                sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
369         // Launch prepare model from cache.
370         sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
371         hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
372         Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModelFromCache_1_3(
373                 {}, modelCache, dataCache, cacheToken, preparedModelCallback);
374         ASSERT_TRUE(prepareLaunchStatus.isOk());
375         if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
376             *preparedModel = nullptr;
377             *status = static_cast<ErrorStatus>(prepareLaunchStatus);
378             return;
379         }
380 
381         // Retrieve prepared model.
382         preparedModelCallback->wait();
383         *status = preparedModelCallback->getStatus();
384         *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
385                                  .withDefault(nullptr);
386     }
387 
388     // Absolute path to the temporary cache directory.
389     std::string mCacheDir;
390 
391     // Groups of file paths for model and data cache in the tmp cache directory, initialized with
392     // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
393     // and the inner vector is for fds held by each handle.
394     std::vector<std::vector<std::string>> mModelCache;
395     std::vector<std::vector<std::string>> mDataCache;
396 
397     // A separate temporary file path in the tmp cache directory.
398     std::string mTmpCache;
399 
400     uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
401     uint32_t mNumModelCache;
402     uint32_t mNumDataCache;
403     uint32_t mIsCachingSupported;
404 
405     const sp<IDevice> kDevice;
406     // The primary data type of the testModel.
407     const OperandType kOperandType;
408 };
409 
410 using CompilationCachingTestParam = std::tuple<NamedDevice, OperandType>;
411 
412 // A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
413 // pass running with float32 models and the second pass running with quant8 models.
414 class CompilationCachingTest : public CompilationCachingTestBase,
415                                public testing::WithParamInterface<CompilationCachingTestParam> {
416   protected:
CompilationCachingTest()417     CompilationCachingTest()
418         : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
419                                      std::get<OperandType>(GetParam())) {}
420 };
421 
// Saves a compilation to the cache files, retrieves it again with prepareModelFromCache and
// checks that the retrieved prepared model produces the expected results.
//
// If the device does not support caching, retrieval must fail with GENERAL_FAILURE. If it does
// but retrieval reports GENERAL_FAILURE, the test ends early without evaluating a model.
TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    // The helpers report errors with ASSERT_*, which only returns from the helper itself, so
    // their fatal failures are propagated explicitly.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        ASSERT_NO_FATAL_FAILURE(
                createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache));
        ASSERT_NO_FATAL_FAILURE(
                createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache));
        ASSERT_NO_FATAL_FAILURE(saveModelToCache(model, modelCache, dataCache));
    }

    // Retrieve preparedModel from cache.
    sp<IPreparedModel> preparedModel = nullptr;
    {
        // Start from a defined value so that status is never read uninitialized.
        ErrorStatus status = ErrorStatus::GENERAL_FAILURE;
        hidl_vec<hidl_handle> modelCache, dataCache;
        ASSERT_NO_FATAL_FAILURE(
                createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache));
        ASSERT_NO_FATAL_FAILURE(
                createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache));
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        if (!mIsCachingSupported) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else {
            ASSERT_EQ(status, ErrorStatus::NONE);
            ASSERT_NE(preparedModel, nullptr);
        }
    }

    // Execute and verify results.
    EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
}
461 
TEST_P(CompilationCachingTest,CacheSavingAndRetrievalNonZeroOffset)462 TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
463     // Create test HIDL model and compile.
464     const TestModel& testModel = createTestModel();
465     const Model model = createModel(testModel);
466     if (checkEarlyTermination(model)) return;
467     sp<IPreparedModel> preparedModel = nullptr;
468 
469     // Save the compilation to cache.
470     {
471         hidl_vec<hidl_handle> modelCache, dataCache;
472         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
473         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
474         uint8_t dummyBytes[] = {0, 0};
475         // Write a dummy integer to the cache.
476         // The driver should be able to handle non-empty cache and non-zero fd offset.
477         for (uint32_t i = 0; i < modelCache.size(); i++) {
478             ASSERT_EQ(write(modelCache[i].getNativeHandle()->data[0], &dummyBytes,
479                             sizeof(dummyBytes)),
480                       sizeof(dummyBytes));
481         }
482         for (uint32_t i = 0; i < dataCache.size(); i++) {
483             ASSERT_EQ(
484                     write(dataCache[i].getNativeHandle()->data[0], &dummyBytes, sizeof(dummyBytes)),
485                     sizeof(dummyBytes));
486         }
487         saveModelToCache(model, modelCache, dataCache);
488     }
489 
490     // Retrieve preparedModel from cache.
491     {
492         preparedModel = nullptr;
493         ErrorStatus status;
494         hidl_vec<hidl_handle> modelCache, dataCache;
495         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
496         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
497         uint8_t dummyByte = 0;
498         // Advance the offset of each handle by one byte.
499         // The driver should be able to handle non-zero fd offset.
500         for (uint32_t i = 0; i < modelCache.size(); i++) {
501             ASSERT_GE(read(modelCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
502         }
503         for (uint32_t i = 0; i < dataCache.size(); i++) {
504             ASSERT_GE(read(dataCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
505         }
506         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
507         if (!mIsCachingSupported) {
508             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
509             ASSERT_EQ(preparedModel, nullptr);
510             return;
511         } else if (checkEarlyTermination(status)) {
512             ASSERT_EQ(preparedModel, nullptr);
513             return;
514         } else {
515             ASSERT_EQ(status, ErrorStatus::NONE);
516             ASSERT_NE(preparedModel, nullptr);
517         }
518     }
519 
520     // Execute and verify results.
521     EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
522 }
523 
TEST_P(CompilationCachingTest,SaveToCacheInvalidNumCache)524 TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
525     // Create test HIDL model and compile.
526     const TestModel& testModel = createTestModel();
527     const Model model = createModel(testModel);
528     if (checkEarlyTermination(model)) return;
529 
530     // Test with number of model cache files greater than mNumModelCache.
531     {
532         hidl_vec<hidl_handle> modelCache, dataCache;
533         // Pass an additional cache file for model cache.
534         mModelCache.push_back({mTmpCache});
535         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
536         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
537         mModelCache.pop_back();
538         sp<IPreparedModel> preparedModel = nullptr;
539         saveModelToCache(model, modelCache, dataCache, &preparedModel);
540         ASSERT_NE(preparedModel, nullptr);
541         // Execute and verify results.
542         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
543         // Check if prepareModelFromCache fails.
544         preparedModel = nullptr;
545         ErrorStatus status;
546         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
547         if (status != ErrorStatus::INVALID_ARGUMENT) {
548             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
549         }
550         ASSERT_EQ(preparedModel, nullptr);
551     }
552 
553     // Test with number of model cache files smaller than mNumModelCache.
554     if (mModelCache.size() > 0) {
555         hidl_vec<hidl_handle> modelCache, dataCache;
556         // Pop out the last cache file.
557         auto tmp = mModelCache.back();
558         mModelCache.pop_back();
559         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
560         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
561         mModelCache.push_back(tmp);
562         sp<IPreparedModel> preparedModel = nullptr;
563         saveModelToCache(model, modelCache, dataCache, &preparedModel);
564         ASSERT_NE(preparedModel, nullptr);
565         // Execute and verify results.
566         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
567         // Check if prepareModelFromCache fails.
568         preparedModel = nullptr;
569         ErrorStatus status;
570         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
571         if (status != ErrorStatus::INVALID_ARGUMENT) {
572             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
573         }
574         ASSERT_EQ(preparedModel, nullptr);
575     }
576 
577     // Test with number of data cache files greater than mNumDataCache.
578     {
579         hidl_vec<hidl_handle> modelCache, dataCache;
580         // Pass an additional cache file for data cache.
581         mDataCache.push_back({mTmpCache});
582         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
583         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
584         mDataCache.pop_back();
585         sp<IPreparedModel> preparedModel = nullptr;
586         saveModelToCache(model, modelCache, dataCache, &preparedModel);
587         ASSERT_NE(preparedModel, nullptr);
588         // Execute and verify results.
589         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
590         // Check if prepareModelFromCache fails.
591         preparedModel = nullptr;
592         ErrorStatus status;
593         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
594         if (status != ErrorStatus::INVALID_ARGUMENT) {
595             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
596         }
597         ASSERT_EQ(preparedModel, nullptr);
598     }
599 
600     // Test with number of data cache files smaller than mNumDataCache.
601     if (mDataCache.size() > 0) {
602         hidl_vec<hidl_handle> modelCache, dataCache;
603         // Pop out the last cache file.
604         auto tmp = mDataCache.back();
605         mDataCache.pop_back();
606         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
607         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
608         mDataCache.push_back(tmp);
609         sp<IPreparedModel> preparedModel = nullptr;
610         saveModelToCache(model, modelCache, dataCache, &preparedModel);
611         ASSERT_NE(preparedModel, nullptr);
612         // Execute and verify results.
613         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
614         // Check if prepareModelFromCache fails.
615         preparedModel = nullptr;
616         ErrorStatus status;
617         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
618         if (status != ErrorStatus::INVALID_ARGUMENT) {
619             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
620         }
621         ASSERT_EQ(preparedModel, nullptr);
622     }
623 }
624 
TEST_P(CompilationCachingTest,PrepareModelFromCacheInvalidNumCache)625 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
626     // Create test HIDL model and compile.
627     const TestModel& testModel = createTestModel();
628     const Model model = createModel(testModel);
629     if (checkEarlyTermination(model)) return;
630 
631     // Save the compilation to cache.
632     {
633         hidl_vec<hidl_handle> modelCache, dataCache;
634         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
635         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
636         saveModelToCache(model, modelCache, dataCache);
637     }
638 
639     // Test with number of model cache files greater than mNumModelCache.
640     {
641         sp<IPreparedModel> preparedModel = nullptr;
642         ErrorStatus status;
643         hidl_vec<hidl_handle> modelCache, dataCache;
644         mModelCache.push_back({mTmpCache});
645         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
646         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
647         mModelCache.pop_back();
648         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
649         if (status != ErrorStatus::GENERAL_FAILURE) {
650             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
651         }
652         ASSERT_EQ(preparedModel, nullptr);
653     }
654 
655     // Test with number of model cache files smaller than mNumModelCache.
656     if (mModelCache.size() > 0) {
657         sp<IPreparedModel> preparedModel = nullptr;
658         ErrorStatus status;
659         hidl_vec<hidl_handle> modelCache, dataCache;
660         auto tmp = mModelCache.back();
661         mModelCache.pop_back();
662         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
663         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
664         mModelCache.push_back(tmp);
665         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
666         if (status != ErrorStatus::GENERAL_FAILURE) {
667             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
668         }
669         ASSERT_EQ(preparedModel, nullptr);
670     }
671 
672     // Test with number of data cache files greater than mNumDataCache.
673     {
674         sp<IPreparedModel> preparedModel = nullptr;
675         ErrorStatus status;
676         hidl_vec<hidl_handle> modelCache, dataCache;
677         mDataCache.push_back({mTmpCache});
678         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
679         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
680         mDataCache.pop_back();
681         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
682         if (status != ErrorStatus::GENERAL_FAILURE) {
683             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
684         }
685         ASSERT_EQ(preparedModel, nullptr);
686     }
687 
688     // Test with number of data cache files smaller than mNumDataCache.
689     if (mDataCache.size() > 0) {
690         sp<IPreparedModel> preparedModel = nullptr;
691         ErrorStatus status;
692         hidl_vec<hidl_handle> modelCache, dataCache;
693         auto tmp = mDataCache.back();
694         mDataCache.pop_back();
695         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
696         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
697         mDataCache.push_back(tmp);
698         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
699         if (status != ErrorStatus::GENERAL_FAILURE) {
700             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
701         }
702         ASSERT_EQ(preparedModel, nullptr);
703     }
704 }
705 
TEST_P(CompilationCachingTest,SaveToCacheInvalidNumFd)706 TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
707     // Create test HIDL model and compile.
708     const TestModel& testModel = createTestModel();
709     const Model model = createModel(testModel);
710     if (checkEarlyTermination(model)) return;
711 
712     // Go through each handle in model cache, test with NumFd greater than 1.
713     for (uint32_t i = 0; i < mNumModelCache; i++) {
714         hidl_vec<hidl_handle> modelCache, dataCache;
715         // Pass an invalid number of fds for handle i.
716         mModelCache[i].push_back(mTmpCache);
717         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
718         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
719         mModelCache[i].pop_back();
720         sp<IPreparedModel> preparedModel = nullptr;
721         saveModelToCache(model, modelCache, dataCache, &preparedModel);
722         ASSERT_NE(preparedModel, nullptr);
723         // Execute and verify results.
724         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
725         // Check if prepareModelFromCache fails.
726         preparedModel = nullptr;
727         ErrorStatus status;
728         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
729         if (status != ErrorStatus::INVALID_ARGUMENT) {
730             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
731         }
732         ASSERT_EQ(preparedModel, nullptr);
733     }
734 
735     // Go through each handle in model cache, test with NumFd equal to 0.
736     for (uint32_t i = 0; i < mNumModelCache; i++) {
737         hidl_vec<hidl_handle> modelCache, dataCache;
738         // Pass an invalid number of fds for handle i.
739         auto tmp = mModelCache[i].back();
740         mModelCache[i].pop_back();
741         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
742         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
743         mModelCache[i].push_back(tmp);
744         sp<IPreparedModel> preparedModel = nullptr;
745         saveModelToCache(model, modelCache, dataCache, &preparedModel);
746         ASSERT_NE(preparedModel, nullptr);
747         // Execute and verify results.
748         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
749         // Check if prepareModelFromCache fails.
750         preparedModel = nullptr;
751         ErrorStatus status;
752         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
753         if (status != ErrorStatus::INVALID_ARGUMENT) {
754             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
755         }
756         ASSERT_EQ(preparedModel, nullptr);
757     }
758 
759     // Go through each handle in data cache, test with NumFd greater than 1.
760     for (uint32_t i = 0; i < mNumDataCache; i++) {
761         hidl_vec<hidl_handle> modelCache, dataCache;
762         // Pass an invalid number of fds for handle i.
763         mDataCache[i].push_back(mTmpCache);
764         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
765         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
766         mDataCache[i].pop_back();
767         sp<IPreparedModel> preparedModel = nullptr;
768         saveModelToCache(model, modelCache, dataCache, &preparedModel);
769         ASSERT_NE(preparedModel, nullptr);
770         // Execute and verify results.
771         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
772         // Check if prepareModelFromCache fails.
773         preparedModel = nullptr;
774         ErrorStatus status;
775         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
776         if (status != ErrorStatus::INVALID_ARGUMENT) {
777             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
778         }
779         ASSERT_EQ(preparedModel, nullptr);
780     }
781 
782     // Go through each handle in data cache, test with NumFd equal to 0.
783     for (uint32_t i = 0; i < mNumDataCache; i++) {
784         hidl_vec<hidl_handle> modelCache, dataCache;
785         // Pass an invalid number of fds for handle i.
786         auto tmp = mDataCache[i].back();
787         mDataCache[i].pop_back();
788         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
789         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
790         mDataCache[i].push_back(tmp);
791         sp<IPreparedModel> preparedModel = nullptr;
792         saveModelToCache(model, modelCache, dataCache, &preparedModel);
793         ASSERT_NE(preparedModel, nullptr);
794         // Execute and verify results.
795         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
796         // Check if prepareModelFromCache fails.
797         preparedModel = nullptr;
798         ErrorStatus status;
799         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
800         if (status != ErrorStatus::INVALID_ARGUMENT) {
801             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
802         }
803         ASSERT_EQ(preparedModel, nullptr);
804     }
805 }
806 
TEST_P(CompilationCachingTest,PrepareModelFromCacheInvalidNumFd)807 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
808     // Create test HIDL model and compile.
809     const TestModel& testModel = createTestModel();
810     const Model model = createModel(testModel);
811     if (checkEarlyTermination(model)) return;
812 
813     // Save the compilation to cache.
814     {
815         hidl_vec<hidl_handle> modelCache, dataCache;
816         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
817         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
818         saveModelToCache(model, modelCache, dataCache);
819     }
820 
821     // Go through each handle in model cache, test with NumFd greater than 1.
822     for (uint32_t i = 0; i < mNumModelCache; i++) {
823         sp<IPreparedModel> preparedModel = nullptr;
824         ErrorStatus status;
825         hidl_vec<hidl_handle> modelCache, dataCache;
826         mModelCache[i].push_back(mTmpCache);
827         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
828         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
829         mModelCache[i].pop_back();
830         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
831         if (status != ErrorStatus::GENERAL_FAILURE) {
832             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
833         }
834         ASSERT_EQ(preparedModel, nullptr);
835     }
836 
837     // Go through each handle in model cache, test with NumFd equal to 0.
838     for (uint32_t i = 0; i < mNumModelCache; i++) {
839         sp<IPreparedModel> preparedModel = nullptr;
840         ErrorStatus status;
841         hidl_vec<hidl_handle> modelCache, dataCache;
842         auto tmp = mModelCache[i].back();
843         mModelCache[i].pop_back();
844         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
845         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
846         mModelCache[i].push_back(tmp);
847         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
848         if (status != ErrorStatus::GENERAL_FAILURE) {
849             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
850         }
851         ASSERT_EQ(preparedModel, nullptr);
852     }
853 
854     // Go through each handle in data cache, test with NumFd greater than 1.
855     for (uint32_t i = 0; i < mNumDataCache; i++) {
856         sp<IPreparedModel> preparedModel = nullptr;
857         ErrorStatus status;
858         hidl_vec<hidl_handle> modelCache, dataCache;
859         mDataCache[i].push_back(mTmpCache);
860         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
861         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
862         mDataCache[i].pop_back();
863         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
864         if (status != ErrorStatus::GENERAL_FAILURE) {
865             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
866         }
867         ASSERT_EQ(preparedModel, nullptr);
868     }
869 
870     // Go through each handle in data cache, test with NumFd equal to 0.
871     for (uint32_t i = 0; i < mNumDataCache; i++) {
872         sp<IPreparedModel> preparedModel = nullptr;
873         ErrorStatus status;
874         hidl_vec<hidl_handle> modelCache, dataCache;
875         auto tmp = mDataCache[i].back();
876         mDataCache[i].pop_back();
877         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
878         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
879         mDataCache[i].push_back(tmp);
880         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
881         if (status != ErrorStatus::GENERAL_FAILURE) {
882             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
883         }
884         ASSERT_EQ(preparedModel, nullptr);
885     }
886 }
887 
TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
    // Build the HIDL model under test; stop here if checkEarlyTermination asks for it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    // Per-handle access modes; exactly one entry is switched to READ_ONLY at a time.
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Opens the cache files using whatever modes the two vectors hold right now.
    const auto openHandles = [&](hidl_vec<hidl_handle>* modelHandles,
                                 hidl_vec<hidl_handle>* dataHandles) {
        createCacheHandles(mModelCache, modelCacheMode, modelHandles);
        createCacheHandles(mDataCache, dataCacheMode, dataHandles);
    };

    // Model cache: open handle idx read-only, everything else read-write.
    for (uint32_t idx = 0; idx < mNumModelCache; ++idx) {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        modelCacheMode[idx] = AccessMode::READ_ONLY;
        openHandles(&modelHandles, &dataHandles);
        modelCacheMode[idx] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // The freshly prepared model must still execute correctly.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Loading from the cache afterwards has to fail.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Data cache: open handle idx read-only, everything else read-write.
    for (uint32_t idx = 0; idx < mNumDataCache; ++idx) {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        dataCacheMode[idx] = AccessMode::READ_ONLY;
        openHandles(&modelHandles, &dataHandles);
        dataCacheMode[idx] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // The freshly prepared model must still execute correctly.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Loading from the cache afterwards has to fail.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
940 
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
    // Build the HIDL model under test; stop here if checkEarlyTermination asks for it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    // Per-handle access modes; exactly one entry is switched to WRITE_ONLY at a time.
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Populate the cache with a regular compilation, all handles read-write.
    {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        saveModelToCache(model, modelHandles, dataHandles);
    }

    // Opens the cache files using whatever modes the two vectors hold right now.
    const auto openHandles = [&](hidl_vec<hidl_handle>* modelHandles,
                                 hidl_vec<hidl_handle>* dataHandles) {
        createCacheHandles(mModelCache, modelCacheMode, modelHandles);
        createCacheHandles(mDataCache, dataCacheMode, dataHandles);
    };

    // Model cache: open handle idx write-only, everything else read-write.
    for (uint32_t idx = 0; idx < mNumModelCache; ++idx) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        modelCacheMode[idx] = AccessMode::WRITE_ONLY;
        openHandles(&modelHandles, &dataHandles);
        modelCacheMode[idx] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Data cache: open handle idx write-only, everything else read-write.
    for (uint32_t idx = 0; idx < mNumDataCache; ++idx) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        dataCacheMode[idx] = AccessMode::WRITE_ONLY;
        openHandles(&modelHandles, &dataHandles);
        dataCacheMode[idx] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
985 
986 // Copy file contents between file groups.
987 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
988 // The outer vector sizes must match and the inner vectors must have size = 1.
copyCacheFiles(const std::vector<std::vector<std::string>> & from,const std::vector<std::vector<std::string>> & to)989 static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
990                            const std::vector<std::vector<std::string>>& to) {
991     constexpr size_t kBufferSize = 1000000;
992     uint8_t buffer[kBufferSize];
993 
994     ASSERT_EQ(from.size(), to.size());
995     for (uint32_t i = 0; i < from.size(); i++) {
996         ASSERT_EQ(from[i].size(), 1u);
997         ASSERT_EQ(to[i].size(), 1u);
998         int fromFd = open(from[i][0].c_str(), O_RDONLY);
999         int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
1000         ASSERT_GE(fromFd, 0);
1001         ASSERT_GE(toFd, 0);
1002 
1003         ssize_t readBytes;
1004         while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
1005             ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
1006         }
1007         ASSERT_GE(readBytes, 0);
1008 
1009         close(fromFd);
1010         close(toFd);
1011     }
1012 }
1013 
// Number of operations in the large test model.
constexpr uint32_t kLargeModelSize = 100;
// Number of times each TOCTOU test repeats its save / copy / prepare-from-cache cycle.
constexpr uint32_t kNumIterationsTOCTOU = 100;
1017 
TEST_P(CompilationCachingTest,SaveToCache_TOCTOU)1018 TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
1019     if (!mIsCachingSupported) return;
1020 
1021     // Create test models and check if fully supported by the service.
1022     const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
1023     const Model modelMul = createModel(testModelMul);
1024     if (checkEarlyTermination(modelMul)) return;
1025     const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
1026     const Model modelAdd = createModel(testModelAdd);
1027     if (checkEarlyTermination(modelAdd)) return;
1028 
1029     // Save the modelMul compilation to cache.
1030     auto modelCacheMul = mModelCache;
1031     for (auto& cache : modelCacheMul) {
1032         cache[0].append("_mul");
1033     }
1034     {
1035         hidl_vec<hidl_handle> modelCache, dataCache;
1036         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
1037         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1038         saveModelToCache(modelMul, modelCache, dataCache);
1039     }
1040 
1041     // Use a different token for modelAdd.
1042     mToken[0]++;
1043 
1044     // This test is probabilistic, so we run it multiple times.
1045     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
1046         // Save the modelAdd compilation to cache.
1047         {
1048             hidl_vec<hidl_handle> modelCache, dataCache;
1049             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1050             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1051 
1052             // Spawn a thread to copy the cache content concurrently while saving to cache.
1053             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
1054             saveModelToCache(modelAdd, modelCache, dataCache);
1055             thread.join();
1056         }
1057 
1058         // Retrieve preparedModel from cache.
1059         {
1060             sp<IPreparedModel> preparedModel = nullptr;
1061             ErrorStatus status;
1062             hidl_vec<hidl_handle> modelCache, dataCache;
1063             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1064             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1065             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1066 
1067             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
1068             // the prepared model must be executed with the correct result and not crash.
1069             if (status != ErrorStatus::NONE) {
1070                 ASSERT_EQ(preparedModel, nullptr);
1071             } else {
1072                 ASSERT_NE(preparedModel, nullptr);
1073                 EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
1074                                       /*testKind=*/TestKind::GENERAL);
1075             }
1076         }
1077     }
1078 }
1079 
TEST_P(CompilationCachingTest,PrepareFromCache_TOCTOU)1080 TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
1081     if (!mIsCachingSupported) return;
1082 
1083     // Create test models and check if fully supported by the service.
1084     const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
1085     const Model modelMul = createModel(testModelMul);
1086     if (checkEarlyTermination(modelMul)) return;
1087     const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
1088     const Model modelAdd = createModel(testModelAdd);
1089     if (checkEarlyTermination(modelAdd)) return;
1090 
1091     // Save the modelMul compilation to cache.
1092     auto modelCacheMul = mModelCache;
1093     for (auto& cache : modelCacheMul) {
1094         cache[0].append("_mul");
1095     }
1096     {
1097         hidl_vec<hidl_handle> modelCache, dataCache;
1098         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
1099         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1100         saveModelToCache(modelMul, modelCache, dataCache);
1101     }
1102 
1103     // Use a different token for modelAdd.
1104     mToken[0]++;
1105 
1106     // This test is probabilistic, so we run it multiple times.
1107     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
1108         // Save the modelAdd compilation to cache.
1109         {
1110             hidl_vec<hidl_handle> modelCache, dataCache;
1111             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1112             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1113             saveModelToCache(modelAdd, modelCache, dataCache);
1114         }
1115 
1116         // Retrieve preparedModel from cache.
1117         {
1118             sp<IPreparedModel> preparedModel = nullptr;
1119             ErrorStatus status;
1120             hidl_vec<hidl_handle> modelCache, dataCache;
1121             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1122             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1123 
1124             // Spawn a thread to copy the cache content concurrently while preparing from cache.
1125             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
1126             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1127             thread.join();
1128 
1129             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
1130             // the prepared model must be executed with the correct result and not crash.
1131             if (status != ErrorStatus::NONE) {
1132                 ASSERT_EQ(preparedModel, nullptr);
1133             } else {
1134                 ASSERT_NE(preparedModel, nullptr);
1135                 EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
1136                                       /*testKind=*/TestKind::GENERAL);
1137             }
1138         }
1139     }
1140 }
1141 
TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) return;

    // Two large models (all-MUL and all-ADD); both must pass checkEarlyTermination.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // The MUL compilation goes to a separate set of model cache files ("<name>_mul").
    auto modelCacheMul = mModelCache;
    for (auto& fileGroup : modelCacheMul) {
        fileGroup[0] += "_mul";
    }
    {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        saveModelToCache(modelMul, modelHandles, dataHandles);
    }

    // modelAdd is cached under a different token.
    mToken[0]++;

    // Cache modelAdd in the regular cache files.
    {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        saveModelToCache(modelAdd, modelHandles, dataHandles);
    }

    // Overwrite modelAdd's model cache files with the content cached for modelMul.
    copyCacheFiles(modelCacheMul, mModelCache);

    // Loading the tampered cache must be rejected with GENERAL_FAILURE.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
1191 
// Parameter generators for the compilation caching tests: every device returned by
// getNamedDevices(), and the two operand types (float32 and asymmetric quant8) a test model
// can be built with.
static const auto kNamedDeviceChoices = testing::ValuesIn(getNamedDevices());
static const auto kOperandTypeChoices =
        testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);
1195 
printCompilationCachingTest(const testing::TestParamInfo<CompilationCachingTestParam> & info)1196 std::string printCompilationCachingTest(
1197         const testing::TestParamInfo<CompilationCachingTestParam>& info) {
1198     const auto& [namedDevice, operandType] = info.param;
1199     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1200     return gtestCompliantName(getName(namedDevice) + "_" + type);
1201 }
1202 
// Instantiate CompilationCachingTest for every (device, operand type) combination; the test
// name suffix is produced by printCompilationCachingTest.
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest,
                        testing::Combine(kNamedDeviceChoices, kOperandTypeChoices),
                        printCompilationCachingTest);
1206 
1207 using CompilationCachingSecurityTestParam = std::tuple<NamedDevice, OperandType, uint32_t>;
1208 
1209 class CompilationCachingSecurityTest
1210     : public CompilationCachingTestBase,
1211       public testing::WithParamInterface<CompilationCachingSecurityTestParam> {
1212   protected:
CompilationCachingSecurityTest()1213     CompilationCachingSecurityTest()
1214         : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
1215                                      std::get<OperandType>(GetParam())) {}
1216 
SetUp()1217     void SetUp() {
1218         CompilationCachingTestBase::SetUp();
1219         generator.seed(kSeed);
1220     }
1221 
1222     // Get a random integer within a closed range [lower, upper].
1223     template <typename T>
getRandomInt(T lower,T upper)1224     T getRandomInt(T lower, T upper) {
1225         std::uniform_int_distribution<T> dis(lower, upper);
1226         return dis(generator);
1227     }
1228 
1229     // Randomly flip one single bit of the cache entry.
flipOneBitOfCache(const std::string & filename,bool * skip)1230     void flipOneBitOfCache(const std::string& filename, bool* skip) {
1231         FILE* pFile = fopen(filename.c_str(), "r+");
1232         ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
1233         long int fileSize = ftell(pFile);
1234         if (fileSize == 0) {
1235             fclose(pFile);
1236             *skip = true;
1237             return;
1238         }
1239         ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
1240         int readByte = fgetc(pFile);
1241         ASSERT_NE(readByte, EOF);
1242         ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
1243         ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
1244         fclose(pFile);
1245         *skip = false;
1246     }
1247 
1248     // Randomly append bytes to the cache entry.
appendBytesToCache(const std::string & filename,bool * skip)1249     void appendBytesToCache(const std::string& filename, bool* skip) {
1250         FILE* pFile = fopen(filename.c_str(), "a");
1251         uint32_t appendLength = getRandomInt(1, 256);
1252         for (uint32_t i = 0; i < appendLength; i++) {
1253             ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
1254         }
1255         fclose(pFile);
1256         *skip = false;
1257     }
1258 
1259     enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
1260 
1261     // Test if the driver behaves as expected when given corrupted cache or token.
1262     // The modifier will be invoked after save to cache but before prepare from cache.
1263     // The modifier accepts one pointer argument "skip" as the returning value, indicating
1264     // whether the test should be skipped or not.
testCorruptedCache(ExpectedResult expected,std::function<void (bool *)> modifier)1265     void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
1266         const TestModel& testModel = createTestModel();
1267         const Model model = createModel(testModel);
1268         if (checkEarlyTermination(model)) return;
1269 
1270         // Save the compilation to cache.
1271         {
1272             hidl_vec<hidl_handle> modelCache, dataCache;
1273             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1274             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1275             saveModelToCache(model, modelCache, dataCache);
1276         }
1277 
1278         bool skip = false;
1279         modifier(&skip);
1280         if (skip) return;
1281 
1282         // Retrieve preparedModel from cache.
1283         {
1284             sp<IPreparedModel> preparedModel = nullptr;
1285             ErrorStatus status;
1286             hidl_vec<hidl_handle> modelCache, dataCache;
1287             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1288             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1289             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1290 
1291             switch (expected) {
1292                 case ExpectedResult::GENERAL_FAILURE:
1293                     ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1294                     ASSERT_EQ(preparedModel, nullptr);
1295                     break;
1296                 case ExpectedResult::NOT_CRASH:
1297                     ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
1298                     break;
1299                 default:
1300                     FAIL();
1301             }
1302         }
1303     }
1304 
1305     const uint32_t kSeed = std::get<uint32_t>(GetParam());
1306     std::mt19937 generator;
1307 };
1308 
// For every model cache entry: save the compilation, tamper with the entry through
// flipOneBitOfCache, and require prepare-from-cache to fail with GENERAL_FAILURE.
// NOTE(review): flipOneBitOfCache is defined outside this view; by its name it flips a single
// bit of the given cache file and may request a skip through "skip" - confirm.
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    // Nothing to verify when caching is not supported.
    if (!mIsCachingSupported) {
        return;
    }
    for (uint32_t cacheIndex = 0; cacheIndex < mNumModelCache; ++cacheIndex) {
        const auto corruptModelCache = [this, cacheIndex](bool* skip) {
            flipOneBitOfCache(mModelCache[cacheIndex][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptModelCache);
    }
}
1316 
// For every model cache entry: save the compilation, tamper with the entry through
// appendBytesToCache, and require prepare-from-cache to fail with GENERAL_FAILURE.
// NOTE(review): appendBytesToCache is defined outside this view; by its name it grows the
// given cache file and may request a skip through "skip" - confirm.
TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    // Nothing to verify when caching is not supported.
    if (!mIsCachingSupported) {
        return;
    }
    for (uint32_t cacheIndex = 0; cacheIndex < mNumModelCache; ++cacheIndex) {
        const auto growModelCache = [this, cacheIndex](bool* skip) {
            appendBytesToCache(mModelCache[cacheIndex][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, growModelCache);
    }
}
1324 
// For every data cache entry: save the compilation, tamper with the entry through
// flipOneBitOfCache, then prepare from cache. Only NOT_CRASH is expected, so the driver may
// succeed or fail as long as the status and the returned prepared model agree.
// NOTE(review): flipOneBitOfCache is defined outside this view; by its name it flips a single
// bit of the given cache file and may request a skip through "skip" - confirm.
TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    // Nothing to verify when caching is not supported.
    if (!mIsCachingSupported) {
        return;
    }
    for (uint32_t cacheIndex = 0; cacheIndex < mNumDataCache; ++cacheIndex) {
        const auto corruptDataCache = [this, cacheIndex](bool* skip) {
            flipOneBitOfCache(mDataCache[cacheIndex][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corruptDataCache);
    }
}
1332 
// For every data cache entry: save the compilation, tamper with the entry through
// appendBytesToCache, then prepare from cache. Only NOT_CRASH is expected, so the driver may
// succeed or fail as long as the status and the returned prepared model agree.
// NOTE(review): appendBytesToCache is defined outside this view; by its name it grows the
// given cache file and may request a skip through "skip" - confirm.
TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    // Nothing to verify when caching is not supported.
    if (!mIsCachingSupported) {
        return;
    }
    for (uint32_t cacheIndex = 0; cacheIndex < mNumDataCache; ++cacheIndex) {
        const auto growDataCache = [this, cacheIndex](bool* skip) {
            appendBytesToCache(mDataCache[cacheIndex][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, growDataCache);
    }
}
1340 
// Saves the compilation, flips one randomly chosen bit of mToken, and requires
// prepare-from-cache to fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, WrongToken) {
    // Nothing to verify when caching is not supported.
    if (!mIsCachingSupported) {
        return;
    }
    const auto corruptToken = [this](bool* skip) {
        // Pick the byte first and the bit second, both at random.
        const uint32_t lastByte = static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN) - 1;
        const uint32_t byteIndex = getRandomInt(0u, lastByte);
        const auto bitIndex = getRandomInt(0, 7);
        mToken[byteIndex] ^= (1U << bitIndex);
        // Tampering with the token always applies, so the check is never skipped.
        *skip = false;
    };
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptToken);
}
1351 
printCompilationCachingSecurityTest(const testing::TestParamInfo<CompilationCachingSecurityTestParam> & info)1352 std::string printCompilationCachingSecurityTest(
1353         const testing::TestParamInfo<CompilationCachingSecurityTestParam>& info) {
1354     const auto& [namedDevice, operandType, seed] = info.param;
1355     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1356     return gtestCompliantName(getName(namedDevice) + "_" + type + "_" + std::to_string(seed));
1357 }
1358 
1359 INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingSecurityTest,
1360                         testing::Combine(kNamedDeviceChoices, kOperandTypeChoices,
1361                                          testing::Range(0U, 10U)),
1362                         printCompilationCachingSecurityTest);
1363 
1364 }  // namespace android::hardware::neuralnetworks::V1_3::vts::functional
1365