1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include <android-base/logging.h>
20 #include <fcntl.h>
21 #include <ftw.h>
22 #include <gtest/gtest.h>
23 #include <hidlmemory/mapping.h>
24 #include <unistd.h>
25 
26 #include <cstdio>
27 #include <cstdlib>
28 #include <random>
29 #include <thread>
30 
31 #include "1.2/Callbacks.h"
32 #include "GeneratedTestHarness.h"
33 #include "MemoryUtils.h"
34 #include "TestHarness.h"
35 #include "VtsHalNeuralnetworks.h"
36 
37 // Forward declaration of the mobilenet generated test models in
38 // frameworks/ml/nn/runtime/test/generated/.
39 namespace generated_tests::mobilenet_224_gender_basic_fixed {
40 const test_helper::TestModel& get_test_model();
41 }  // namespace generated_tests::mobilenet_224_gender_basic_fixed
42 
43 namespace generated_tests::mobilenet_quantized {
44 const test_helper::TestModel& get_test_model();
45 }  // namespace generated_tests::mobilenet_quantized
46 
47 namespace android::hardware::neuralnetworks::V1_2::vts::functional {
48 
49 using namespace test_helper;
50 using implementation::PreparedModelCallback;
51 using V1_0::ErrorStatus;
52 using V1_1::ExecutionPreference;
53 
// Aliases that bind the fixture's generic model getter to a concrete generated
// mobilenet model: float32 tests use the fixed gender-basic model and quant8
// tests use the quantized model (see the forward declarations above).
namespace float32_model {

// constexpr function pointer to the float32 generated test model getter.
constexpr auto get_test_model = generated_tests::mobilenet_224_gender_basic_fixed::get_test_model;

}  // namespace float32_model

namespace quant8_model {

// constexpr function pointer to the quant8 generated test model getter.
constexpr auto get_test_model = generated_tests::mobilenet_quantized::get_test_model;

}  // namespace quant8_model
65 
66 namespace {
67 
// Access mode used by createCacheHandles when opening each cache file.
enum class AccessMode { READ_WRITE, READ_ONLY, WRITE_ONLY };
69 
70 // Creates cache handles based on provided file groups.
71 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,const std::vector<AccessMode> & mode,hidl_vec<hidl_handle> * handles)72 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
73                         const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
74     handles->resize(fileGroups.size());
75     for (uint32_t i = 0; i < fileGroups.size(); i++) {
76         std::vector<int> fds;
77         for (const auto& file : fileGroups[i]) {
78             int fd;
79             if (mode[i] == AccessMode::READ_ONLY) {
80                 fd = open(file.c_str(), O_RDONLY);
81             } else if (mode[i] == AccessMode::WRITE_ONLY) {
82                 fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
83             } else if (mode[i] == AccessMode::READ_WRITE) {
84                 fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
85             } else {
86                 FAIL();
87             }
88             ASSERT_GE(fd, 0);
89             fds.push_back(fd);
90         }
91         native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
92         ASSERT_NE(cacheNativeHandle, nullptr);
93         std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
94         (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
95     }
96 }
97 
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,AccessMode mode,hidl_vec<hidl_handle> * handles)98 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
99                         hidl_vec<hidl_handle>* handles) {
100     createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
101 }
102 
// Create a chain of broadcast operations. The second operand is always constant tensor [1].
// For simplicity, activation scalar is shared. The second operand is not shared
// in the model to let driver maintain a non-trivial size of constant data and the corresponding
// data locations in cache.
//
//                --------- activation --------
//                ↓      ↓      ↓             ↓
// E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
//                ↑      ↑      ↑             ↑
//               [1]    [1]    [1]           [1]
//
// This function assumes the operation is either ADD or MUL.
//
// The returned model has (2 * len + 2) operands: operand 0 is the shared activation
// scalar, operand (2 * len + 1) is the model output, and each chain link i owns a pair
// of input operands at indexes (2i + 1, 2i + 2). Both the model input value ([1]) and
// the expected output value are embedded in the operands' data buffers.
template <typename CppType, TestOperandType operandType>
TestModel createLargeTestModelImpl(TestOperationType op, uint32_t len) {
    EXPECT_TRUE(op == TestOperationType::ADD || op == TestOperationType::MUL);

    // Model operations and operands.
    std::vector<TestOperation> operations(len);
    std::vector<TestOperand> operands(len * 2 + 2);

    // The activation scalar, value = 0. Consumed by every operation in the chain.
    operands[0] = {
            .type = TestOperandType::INT32,
            .dimensions = {},
            .numberOfConsumers = len,
            .scale = 0.0f,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::CONSTANT_COPY,
            .data = TestBuffer::createFromVector<int32_t>({0}),
    };

    // The buffer value of the constant second operand. The logical value is always 1.0f.
    CppType bufferValue;
    // The scale of the first and second operand.
    float scale1, scale2;
    if (operandType == TestOperandType::TENSOR_FLOAT32) {
        // Float tensors carry no quantization parameters.
        bufferValue = 1.0f;
        scale1 = 0.0f;
        scale2 = 0.0f;
    } else if (op == TestOperationType::ADD) {
        bufferValue = 1;
        scale1 = 1.0f;
        scale2 = 1.0f;
    } else {
        // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
        // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
        bufferValue = 2;
        scale1 = 1.0f;
        scale2 = 0.5f;
    }

    for (uint32_t i = 0; i < len; i++) {
        const uint32_t firstInputIndex = i * 2 + 1;
        const uint32_t secondInputIndex = firstInputIndex + 1;
        const uint32_t outputIndex = secondInputIndex + 1;

        // The first operation input. Only the head of the chain (i == 0) is a model input
        // with concrete data; the rest are temporaries produced by the previous operation.
        operands[firstInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale1,
                .zeroPoint = 0,
                .lifetime = (i == 0 ? TestOperandLifeTime::MODEL_INPUT
                                    : TestOperandLifeTime::TEMPORARY_VARIABLE),
                .data = (i == 0 ? TestBuffer::createFromVector<CppType>({1}) : TestBuffer()),
        };

        // The second operation input, value = 1. Deliberately a fresh constant per link
        // (not shared) so the driver has non-trivial constant data to cache.
        operands[secondInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale2,
                .zeroPoint = 0,
                .lifetime = TestOperandLifeTime::CONSTANT_COPY,
                .data = TestBuffer::createFromVector<CppType>({bufferValue}),
        };

        // The operation. All operations share the same activation scalar.
        // The output operand is created as an input in the next iteration of the loop, in the case
        // of all but the last member of the chain; and after the loop as a model output, in the
        // case of the last member of the chain.
        operations[i] = {
                .type = op,
                .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
                .outputs = {outputIndex},
        };
    }

    // For TestOperationType::ADD, output = 1 + 1 * len = len + 1
    // For TestOperationType::MUL, output = 1 * 1 ^ len = 1
    // NOTE(review): for the quant8 instantiation (CppType = uint8_t), len + 1 must fit in
    // uint8_t — assumes callers keep len <= 254; confirm at call sites.
    CppType outputResult = static_cast<CppType>(op == TestOperationType::ADD ? len + 1u : 1u);

    // The model output. Its data buffer holds the expected result used for verification.
    operands.back() = {
            .type = operandType,
            .dimensions = {1},
            .numberOfConsumers = 0,
            .scale = scale1,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
            .data = TestBuffer::createFromVector<CppType>({outputResult}),
    };

    return {
            .main = {.operands = std::move(operands),
                     .operations = std::move(operations),
                     .inputIndexes = {1},
                     .outputIndexes = {len * 2 + 1}},
            .isRelaxed = false,
    };
}
216 
217 }  // namespace
218 
// Tag for the compilation caching tests.
// Base fixture: creates a per-test cache directory, queries the device for the number of
// model/data cache files it needs, and provides helpers to save a compilation to cache and
// to re-prepare a model from cached files.
class CompilationCachingTestBase : public testing::Test {
  protected:
    // `device` is the service under test; `type` selects between the float32 and quant8
    // variants of the test models (see createTestModel / createLargeTestModel).
    CompilationCachingTestBase(sp<IDevice> device, OperandType type)
        : kDevice(std::move(device)), kOperandType(type) {}

    void SetUp() override {
        testing::Test::SetUp();
        ASSERT_NE(kDevice.get(), nullptr);

        // Create cache directory. The cache directory and a temporary cache file is always created
        // to test the behavior of prepareModelFromCache, even when caching is not supported.
        char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
        char* cacheDir = mkdtemp(cacheDirTemp);
        ASSERT_NE(cacheDir, nullptr);
        mCacheDir = cacheDir;
        mCacheDir.push_back('/');

        // Ask the driver how many model/data cache files it needs. Caching is considered
        // supported iff it requests at least one file of either kind.
        Return<void> ret = kDevice->getNumberOfCacheFilesNeeded(
                [this](ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
                    EXPECT_EQ(ErrorStatus::NONE, status);
                    mNumModelCache = numModelCache;
                    mNumDataCache = numDataCache;
                });
        EXPECT_TRUE(ret.isOk());
        mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;

        // Create empty cache files.
        mTmpCache = mCacheDir + "tmp";
        for (uint32_t i = 0; i < mNumModelCache; i++) {
            mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
        }
        for (uint32_t i = 0; i < mNumDataCache; i++) {
            mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
        }
        // Dummy handles, use AccessMode::WRITE_ONLY for createCacheHandles to create files.
        // The handles are discarded at the end of SetUp; they exist only to create the files.
        hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
        createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
        createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
        createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);

        if (!mIsCachingSupported) {
            LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
                         "support compilation caching.";
            std::cout << "[          ]   Early termination of test because vendor service does not "
                         "support compilation caching."
                      << std::endl;
        }
    }

    void TearDown() override {
        // If the test passes, remove the tmp directory.  Otherwise, keep it for debugging purposes.
        if (!testing::Test::HasFailure()) {
            // Recursively remove the cache directory specified by mCacheDir.
            auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
                return remove(entry);
            };
            // FTW_DEPTH visits children before their directory so remove() can succeed.
            nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
        }
        testing::Test::TearDown();
    }

    // Model and examples creators. According to kOperandType, the following methods will return
    // either float32 model/examples or the quant8 variant.
    TestModel createTestModel() {
        if (kOperandType == OperandType::TENSOR_FLOAT32) {
            return float32_model::get_test_model();
        } else {
            return quant8_model::get_test_model();
        }
    }

    // Builds a chain of `len` ADD or MUL operations (see createLargeTestModelImpl) with the
    // element type selected by kOperandType.
    TestModel createLargeTestModel(OperationType op, uint32_t len) {
        if (kOperandType == OperandType::TENSOR_FLOAT32) {
            return createLargeTestModelImpl<float, TestOperandType::TENSOR_FLOAT32>(
                    static_cast<TestOperationType>(op), len);
        } else {
            return createLargeTestModelImpl<uint8_t, TestOperandType::TENSOR_QUANT8_ASYMM>(
                    static_cast<TestOperationType>(op), len);
        }
    }

    // See if the service can handle the model.
    // Returns true iff the service reports every operation of `model` as supported.
    bool isModelFullySupported(const Model& model) {
        bool fullySupportsModel = false;
        Return<void> supportedCall = kDevice->getSupportedOperations_1_2(
                model,
                [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
                    ASSERT_EQ(ErrorStatus::NONE, status);
                    ASSERT_EQ(supported.size(), model.operations.size());
                    fullySupportsModel = std::all_of(supported.begin(), supported.end(),
                                                     [](bool valid) { return valid; });
                });
        EXPECT_TRUE(supportedCall.isOk());
        return fullySupportsModel;
    }

    // Prepares `model` with the given cache handles and token (mToken), asserting success.
    // Blocks until the asynchronous preparation completes. If `preparedModel` is non-null it
    // receives the resulting prepared model (nullptr on cast failure or when an ASSERT fires).
    void saveModelToCache(const Model& model, const hidl_vec<hidl_handle>& modelCache,
                          const hidl_vec<hidl_handle>& dataCache,
                          sp<IPreparedModel>* preparedModel = nullptr) {
        if (preparedModel != nullptr) *preparedModel = nullptr;

        // Launch prepare model.
        sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
        hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
        Return<ErrorStatus> prepareLaunchStatus =
                kDevice->prepareModel_1_2(model, ExecutionPreference::FAST_SINGLE_ANSWER,
                                          modelCache, dataCache, cacheToken, preparedModelCallback);
        ASSERT_TRUE(prepareLaunchStatus.isOk());
        ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);

        // Retrieve prepared model.
        preparedModelCallback->wait();
        ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
        if (preparedModel != nullptr) {
            *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
                                     .withDefault(nullptr);
        }
    }

    // Returns true (and logs) if `status` is GENERAL_FAILURE, i.e. the service cannot save
    // the prepared model; callers use this to end the test early without failing it.
    bool checkEarlyTermination(ErrorStatus status) {
        if (status == ErrorStatus::GENERAL_FAILURE) {
            LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                         "save the prepared model that it does not support.";
            std::cout << "[          ]   Early termination of test because vendor service cannot "
                         "save the prepared model that it does not support."
                      << std::endl;
            return true;
        }
        return false;
    }

    // Returns true (and logs) if the service does not fully support `model`; callers use
    // this to end the test early without failing it.
    bool checkEarlyTermination(const Model& model) {
        if (!isModelFullySupported(model)) {
            LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
                         "prepare model that it does not support.";
            std::cout << "[          ]   Early termination of test because vendor service cannot "
                         "prepare model that it does not support."
                      << std::endl;
            return true;
        }
        return false;
    }

    // Attempts to prepare a model from the given cache handles and token (mToken).
    // Unlike saveModelToCache, failure is not asserted: the launch or callback status is
    // reported through *status and *preparedModel is set to nullptr on failure, so tests
    // can verify both success and expected-failure paths.
    void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
                               const hidl_vec<hidl_handle>& dataCache,
                               sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
        // Launch prepare model from cache.
        sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
        hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
        Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModelFromCache(
                modelCache, dataCache, cacheToken, preparedModelCallback);
        ASSERT_TRUE(prepareLaunchStatus.isOk());
        if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
            *preparedModel = nullptr;
            *status = static_cast<ErrorStatus>(prepareLaunchStatus);
            return;
        }

        // Retrieve prepared model.
        preparedModelCallback->wait();
        *status = preparedModelCallback->getStatus();
        *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
                                 .withDefault(nullptr);
    }

    // Absolute path to the temporary cache directory.
    std::string mCacheDir;

    // Groups of file paths for model and data cache in the tmp cache directory, initialized with
    // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
    // and the inner vector is for fds held by each handle.
    std::vector<std::vector<std::string>> mModelCache;
    std::vector<std::vector<std::string>> mDataCache;

    // A separate temporary file path in the tmp cache directory.
    std::string mTmpCache;

    // Cache token passed to prepareModel_1_2/prepareModelFromCache; all-zero by default.
    uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
    uint32_t mNumModelCache;
    uint32_t mNumDataCache;
    // NOTE(review): used strictly as a boolean flag; declared uint32_t — consider bool.
    uint32_t mIsCachingSupported;

    const sp<IDevice> kDevice;
    // The primary data type of the testModel.
    const OperandType kOperandType;
};
406 
// Test parameter bundle: the device under test and the primary operand type of the model.
using CompilationCachingTestParam = std::tuple<NamedDevice, OperandType>;

// A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
// pass running with float32 models and the second pass running with quant8 models.
class CompilationCachingTest : public CompilationCachingTestBase,
                               public testing::WithParamInterface<CompilationCachingTestParam> {
  protected:
    // Unpacks the test parameter and forwards the device and operand type to the base fixture.
    CompilationCachingTest()
        : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
                                     std::get<OperandType>(GetParam())) {}
};
418 
// Round-trip test: compile a model, save the compilation to the cache files, then
// recreate the prepared model from those files and verify its execution results.
TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    // Build the HIDL model and bail out early if the service cannot handle it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Attempt to recreate the prepared model from the cached files.
    sp<IPreparedModel> preparedModel = nullptr;
    ErrorStatus status;
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
    }

    // A service without caching support must report GENERAL_FAILURE and no model.
    if (!mIsCachingSupported) {
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
        return;
    }
    if (checkEarlyTermination(status)) {
        ASSERT_EQ(preparedModel, nullptr);
        return;
    }
    ASSERT_EQ(status, ErrorStatus::NONE);
    ASSERT_NE(preparedModel, nullptr);

    // Execute the cached prepared model and verify results against the reference.
    EvaluatePreparedModel(preparedModel, testModel,
                          /*testDynamicOutputShape=*/false);
}
459 
// Same round trip as CacheSavingAndRetrieval, but the cache fds handed to the driver have
// non-empty content and non-zero read/write offsets; the driver must tolerate both.
TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    sp<IPreparedModel> preparedModel = nullptr;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        uint8_t dummyBytes[] = {0, 0};
        // Write a dummy integer to the cache.
        // The driver should be able to handle non-empty cache and non-zero fd offset.
        // data[0] is the single fd in each cache handle (inner_size = 1, see mModelCache).
        for (uint32_t i = 0; i < modelCache.size(); i++) {
            ASSERT_EQ(write(modelCache[i].getNativeHandle()->data[0], &dummyBytes,
                            sizeof(dummyBytes)),
                      sizeof(dummyBytes));
        }
        for (uint32_t i = 0; i < dataCache.size(); i++) {
            ASSERT_EQ(
                    write(dataCache[i].getNativeHandle()->data[0], &dummyBytes, sizeof(dummyBytes)),
                    sizeof(dummyBytes));
        }
        saveModelToCache(model, modelCache, dataCache);
    }

    // Retrieve preparedModel from cache.
    {
        preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        uint8_t dummyByte = 0;
        // Advance the offset of each handle by one byte.
        // The driver should be able to handle non-zero fd offset.
        for (uint32_t i = 0; i < modelCache.size(); i++) {
            ASSERT_GE(read(modelCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
        }
        for (uint32_t i = 0; i < dataCache.size(); i++) {
            ASSERT_GE(read(dataCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
        }
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        // Without caching support the retrieval must fail with GENERAL_FAILURE; otherwise
        // it must succeed (or be an allowed early termination).
        if (!mIsCachingSupported) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else {
            ASSERT_EQ(status, ErrorStatus::NONE);
            ASSERT_NE(preparedModel, nullptr);
        }
    }

    // Execute and verify results.
    EvaluatePreparedModel(preparedModel, testModel,
                          /*testDynamicOutputShape=*/false);
}
522 
// Verifies that prepareModel_1_2 with a wrong number of model/data cache files still
// succeeds (the driver must fall back to preparing without caching), while a subsequent
// prepareModelFromCache on those files must fail. Each section temporarily mutates
// mModelCache/mDataCache to build an oversized or undersized handle vector, then restores
// it so later sections (and other tests) see the original configuration.
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Test with number of model cache files greater than mNumModelCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for model cache.
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        // Either INVALID_ARGUMENT or GENERAL_FAILURE is an acceptable failure status.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for data cache.
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(preparedModel, testModel,
                              /*testDynamicOutputShape=*/false);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
627 
// Verifies that prepareModelFromCache rejects handle vectors whose size disagrees with
// the counts the driver reported via getNumberOfCacheFilesNeeded. A valid compilation is
// cached first; each section then temporarily mutates mModelCache/mDataCache to present
// too many or too few files, and restores the member afterwards.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Test with number of model cache files greater than mNumModelCache.
    // Either GENERAL_FAILURE or INVALID_ARGUMENT is an acceptable failure status here
    // and in the sections below.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
708 
TEST_P(CompilationCachingTest,SaveToCacheInvalidNumFd)709 TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
710     // Create test HIDL model and compile.
711     const TestModel& testModel = createTestModel();
712     const Model model = createModel(testModel);
713     if (checkEarlyTermination(model)) return;
714 
715     // Go through each handle in model cache, test with NumFd greater than 1.
716     for (uint32_t i = 0; i < mNumModelCache; i++) {
717         hidl_vec<hidl_handle> modelCache, dataCache;
718         // Pass an invalid number of fds for handle i.
719         mModelCache[i].push_back(mTmpCache);
720         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
721         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
722         mModelCache[i].pop_back();
723         sp<IPreparedModel> preparedModel = nullptr;
724         saveModelToCache(model, modelCache, dataCache, &preparedModel);
725         ASSERT_NE(preparedModel, nullptr);
726         // Execute and verify results.
727         EvaluatePreparedModel(preparedModel, testModel,
728                               /*testDynamicOutputShape=*/false);
729         // Check if prepareModelFromCache fails.
730         preparedModel = nullptr;
731         ErrorStatus status;
732         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
733         if (status != ErrorStatus::INVALID_ARGUMENT) {
734             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
735         }
736         ASSERT_EQ(preparedModel, nullptr);
737     }
738 
739     // Go through each handle in model cache, test with NumFd equal to 0.
740     for (uint32_t i = 0; i < mNumModelCache; i++) {
741         hidl_vec<hidl_handle> modelCache, dataCache;
742         // Pass an invalid number of fds for handle i.
743         auto tmp = mModelCache[i].back();
744         mModelCache[i].pop_back();
745         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
746         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
747         mModelCache[i].push_back(tmp);
748         sp<IPreparedModel> preparedModel = nullptr;
749         saveModelToCache(model, modelCache, dataCache, &preparedModel);
750         ASSERT_NE(preparedModel, nullptr);
751         // Execute and verify results.
752         EvaluatePreparedModel(preparedModel, testModel,
753                               /*testDynamicOutputShape=*/false);
754         // Check if prepareModelFromCache fails.
755         preparedModel = nullptr;
756         ErrorStatus status;
757         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
758         if (status != ErrorStatus::INVALID_ARGUMENT) {
759             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
760         }
761         ASSERT_EQ(preparedModel, nullptr);
762     }
763 
764     // Go through each handle in data cache, test with NumFd greater than 1.
765     for (uint32_t i = 0; i < mNumDataCache; i++) {
766         hidl_vec<hidl_handle> modelCache, dataCache;
767         // Pass an invalid number of fds for handle i.
768         mDataCache[i].push_back(mTmpCache);
769         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
770         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
771         mDataCache[i].pop_back();
772         sp<IPreparedModel> preparedModel = nullptr;
773         saveModelToCache(model, modelCache, dataCache, &preparedModel);
774         ASSERT_NE(preparedModel, nullptr);
775         // Execute and verify results.
776         EvaluatePreparedModel(preparedModel, testModel,
777                               /*testDynamicOutputShape=*/false);
778         // Check if prepareModelFromCache fails.
779         preparedModel = nullptr;
780         ErrorStatus status;
781         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
782         if (status != ErrorStatus::INVALID_ARGUMENT) {
783             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
784         }
785         ASSERT_EQ(preparedModel, nullptr);
786     }
787 
788     // Go through each handle in data cache, test with NumFd equal to 0.
789     for (uint32_t i = 0; i < mNumDataCache; i++) {
790         hidl_vec<hidl_handle> modelCache, dataCache;
791         // Pass an invalid number of fds for handle i.
792         auto tmp = mDataCache[i].back();
793         mDataCache[i].pop_back();
794         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
795         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
796         mDataCache[i].push_back(tmp);
797         sp<IPreparedModel> preparedModel = nullptr;
798         saveModelToCache(model, modelCache, dataCache, &preparedModel);
799         ASSERT_NE(preparedModel, nullptr);
800         // Execute and verify results.
801         EvaluatePreparedModel(preparedModel, testModel,
802                               /*testDynamicOutputShape=*/false);
803         // Check if prepareModelFromCache fails.
804         preparedModel = nullptr;
805         ErrorStatus status;
806         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
807         if (status != ErrorStatus::INVALID_ARGUMENT) {
808             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
809         }
810         ASSERT_EQ(preparedModel, nullptr);
811     }
812 }
813 
TEST_P(CompilationCachingTest,PrepareModelFromCacheInvalidNumFd)814 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
815     // Create test HIDL model and compile.
816     const TestModel& testModel = createTestModel();
817     const Model model = createModel(testModel);
818     if (checkEarlyTermination(model)) return;
819 
820     // Save the compilation to cache.
821     {
822         hidl_vec<hidl_handle> modelCache, dataCache;
823         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
824         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
825         saveModelToCache(model, modelCache, dataCache);
826     }
827 
828     // Go through each handle in model cache, test with NumFd greater than 1.
829     for (uint32_t i = 0; i < mNumModelCache; i++) {
830         sp<IPreparedModel> preparedModel = nullptr;
831         ErrorStatus status;
832         hidl_vec<hidl_handle> modelCache, dataCache;
833         mModelCache[i].push_back(mTmpCache);
834         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
835         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
836         mModelCache[i].pop_back();
837         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
838         if (status != ErrorStatus::GENERAL_FAILURE) {
839             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
840         }
841         ASSERT_EQ(preparedModel, nullptr);
842     }
843 
844     // Go through each handle in model cache, test with NumFd equal to 0.
845     for (uint32_t i = 0; i < mNumModelCache; i++) {
846         sp<IPreparedModel> preparedModel = nullptr;
847         ErrorStatus status;
848         hidl_vec<hidl_handle> modelCache, dataCache;
849         auto tmp = mModelCache[i].back();
850         mModelCache[i].pop_back();
851         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
852         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
853         mModelCache[i].push_back(tmp);
854         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
855         if (status != ErrorStatus::GENERAL_FAILURE) {
856             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
857         }
858         ASSERT_EQ(preparedModel, nullptr);
859     }
860 
861     // Go through each handle in data cache, test with NumFd greater than 1.
862     for (uint32_t i = 0; i < mNumDataCache; i++) {
863         sp<IPreparedModel> preparedModel = nullptr;
864         ErrorStatus status;
865         hidl_vec<hidl_handle> modelCache, dataCache;
866         mDataCache[i].push_back(mTmpCache);
867         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
868         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
869         mDataCache[i].pop_back();
870         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
871         if (status != ErrorStatus::GENERAL_FAILURE) {
872             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
873         }
874         ASSERT_EQ(preparedModel, nullptr);
875     }
876 
877     // Go through each handle in data cache, test with NumFd equal to 0.
878     for (uint32_t i = 0; i < mNumDataCache; i++) {
879         sp<IPreparedModel> preparedModel = nullptr;
880         ErrorStatus status;
881         hidl_vec<hidl_handle> modelCache, dataCache;
882         auto tmp = mDataCache[i].back();
883         mDataCache[i].pop_back();
884         createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
885         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
886         mDataCache[i].push_back(tmp);
887         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
888         if (status != ErrorStatus::GENERAL_FAILURE) {
889             ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
890         }
891         ASSERT_EQ(preparedModel, nullptr);
892     }
893 }
894 
TEST_P(CompilationCachingTest,SaveToCacheInvalidAccessMode)895 TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
896     // Create test HIDL model and compile.
897     const TestModel& testModel = createTestModel();
898     const Model model = createModel(testModel);
899     if (checkEarlyTermination(model)) return;
900     std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
901     std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);
902 
903     // Go through each handle in model cache, test with invalid access mode.
904     for (uint32_t i = 0; i < mNumModelCache; i++) {
905         hidl_vec<hidl_handle> modelCache, dataCache;
906         modelCacheMode[i] = AccessMode::READ_ONLY;
907         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
908         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
909         modelCacheMode[i] = AccessMode::READ_WRITE;
910         sp<IPreparedModel> preparedModel = nullptr;
911         saveModelToCache(model, modelCache, dataCache, &preparedModel);
912         ASSERT_NE(preparedModel, nullptr);
913         // Execute and verify results.
914         EvaluatePreparedModel(preparedModel, testModel,
915                               /*testDynamicOutputShape=*/false);
916         // Check if prepareModelFromCache fails.
917         preparedModel = nullptr;
918         ErrorStatus status;
919         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
920         if (status != ErrorStatus::INVALID_ARGUMENT) {
921             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
922         }
923         ASSERT_EQ(preparedModel, nullptr);
924     }
925 
926     // Go through each handle in data cache, test with invalid access mode.
927     for (uint32_t i = 0; i < mNumDataCache; i++) {
928         hidl_vec<hidl_handle> modelCache, dataCache;
929         dataCacheMode[i] = AccessMode::READ_ONLY;
930         createCacheHandles(mModelCache, modelCacheMode, &modelCache);
931         createCacheHandles(mDataCache, dataCacheMode, &dataCache);
932         dataCacheMode[i] = AccessMode::READ_WRITE;
933         sp<IPreparedModel> preparedModel = nullptr;
934         saveModelToCache(model, modelCache, dataCache, &preparedModel);
935         ASSERT_NE(preparedModel, nullptr);
936         // Execute and verify results.
937         EvaluatePreparedModel(preparedModel, testModel,
938                               /*testDynamicOutputShape=*/false);
939         // Check if prepareModelFromCache fails.
940         preparedModel = nullptr;
941         ErrorStatus status;
942         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
943         if (status != ErrorStatus::INVALID_ARGUMENT) {
944             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
945         }
946         ASSERT_EQ(preparedModel, nullptr);
947     }
948 }
949 
// Verify that prepareModelFromCache fails with GENERAL_FAILURE when any cache
// file is opened with the wrong access mode (WRITE_ONLY instead of READ_WRITE),
// after a valid save to cache.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Go through each handle in model cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Only handle i gets the invalid (write-only) mode; restore it afterwards.
        modelCacheMode[i] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        modelCacheMode[i] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Only handle i gets the invalid (write-only) mode; restore it afterwards.
        dataCacheMode[i] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        dataCacheMode[i] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
994 
995 // Copy file contents between file groups.
996 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
997 // The outer vector sizes must match and the inner vectors must have size = 1.
copyCacheFiles(const std::vector<std::vector<std::string>> & from,const std::vector<std::vector<std::string>> & to)998 static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
999                            const std::vector<std::vector<std::string>>& to) {
1000     constexpr size_t kBufferSize = 1000000;
1001     uint8_t buffer[kBufferSize];
1002 
1003     ASSERT_EQ(from.size(), to.size());
1004     for (uint32_t i = 0; i < from.size(); i++) {
1005         ASSERT_EQ(from[i].size(), 1u);
1006         ASSERT_EQ(to[i].size(), 1u);
1007         int fromFd = open(from[i][0].c_str(), O_RDONLY);
1008         int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
1009         ASSERT_GE(fromFd, 0);
1010         ASSERT_GE(toFd, 0);
1011 
1012         ssize_t readBytes;
1013         while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
1014             ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
1015         }
1016         ASSERT_GE(readBytes, 0);
1017 
1018         close(fromFd);
1019         close(toFd);
1020     }
1021 }
1022 
// Number of operations in the large test model.
constexpr uint32_t kLargeModelSize = 100;
// Number of repetitions for the probabilistic TOCTOU tests below.
constexpr uint32_t kNumIterationsTOCTOU = 100;
1026 
TEST_P(CompilationCachingTest,SaveToCache_TOCTOU)1027 TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
1028     if (!mIsCachingSupported) return;
1029 
1030     // Create test models and check if fully supported by the service.
1031     const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
1032     const Model modelMul = createModel(testModelMul);
1033     if (checkEarlyTermination(modelMul)) return;
1034     const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
1035     const Model modelAdd = createModel(testModelAdd);
1036     if (checkEarlyTermination(modelAdd)) return;
1037 
1038     // Save the modelMul compilation to cache.
1039     auto modelCacheMul = mModelCache;
1040     for (auto& cache : modelCacheMul) {
1041         cache[0].append("_mul");
1042     }
1043     {
1044         hidl_vec<hidl_handle> modelCache, dataCache;
1045         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
1046         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1047         saveModelToCache(modelMul, modelCache, dataCache);
1048     }
1049 
1050     // Use a different token for modelAdd.
1051     mToken[0]++;
1052 
1053     // This test is probabilistic, so we run it multiple times.
1054     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
1055         // Save the modelAdd compilation to cache.
1056         {
1057             hidl_vec<hidl_handle> modelCache, dataCache;
1058             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1059             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1060 
1061             // Spawn a thread to copy the cache content concurrently while saving to cache.
1062             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
1063             saveModelToCache(modelAdd, modelCache, dataCache);
1064             thread.join();
1065         }
1066 
1067         // Retrieve preparedModel from cache.
1068         {
1069             sp<IPreparedModel> preparedModel = nullptr;
1070             ErrorStatus status;
1071             hidl_vec<hidl_handle> modelCache, dataCache;
1072             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1073             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1074             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1075 
1076             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
1077             // the prepared model must be executed with the correct result and not crash.
1078             if (status != ErrorStatus::NONE) {
1079                 ASSERT_EQ(preparedModel, nullptr);
1080             } else {
1081                 ASSERT_NE(preparedModel, nullptr);
1082                 EvaluatePreparedModel(preparedModel, testModelAdd,
1083                                       /*testDynamicOutputShape=*/false);
1084             }
1085         }
1086     }
1087 }
1088 
TEST_P(CompilationCachingTest,PrepareFromCache_TOCTOU)1089 TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
1090     if (!mIsCachingSupported) return;
1091 
1092     // Create test models and check if fully supported by the service.
1093     const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
1094     const Model modelMul = createModel(testModelMul);
1095     if (checkEarlyTermination(modelMul)) return;
1096     const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
1097     const Model modelAdd = createModel(testModelAdd);
1098     if (checkEarlyTermination(modelAdd)) return;
1099 
1100     // Save the modelMul compilation to cache.
1101     auto modelCacheMul = mModelCache;
1102     for (auto& cache : modelCacheMul) {
1103         cache[0].append("_mul");
1104     }
1105     {
1106         hidl_vec<hidl_handle> modelCache, dataCache;
1107         createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
1108         createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1109         saveModelToCache(modelMul, modelCache, dataCache);
1110     }
1111 
1112     // Use a different token for modelAdd.
1113     mToken[0]++;
1114 
1115     // This test is probabilistic, so we run it multiple times.
1116     for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
1117         // Save the modelAdd compilation to cache.
1118         {
1119             hidl_vec<hidl_handle> modelCache, dataCache;
1120             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1121             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1122             saveModelToCache(modelAdd, modelCache, dataCache);
1123         }
1124 
1125         // Retrieve preparedModel from cache.
1126         {
1127             sp<IPreparedModel> preparedModel = nullptr;
1128             ErrorStatus status;
1129             hidl_vec<hidl_handle> modelCache, dataCache;
1130             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1131             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1132 
1133             // Spawn a thread to copy the cache content concurrently while preparing from cache.
1134             std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
1135             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1136             thread.join();
1137 
1138             // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
1139             // the prepared model must be executed with the correct result and not crash.
1140             if (status != ErrorStatus::NONE) {
1141                 ASSERT_EQ(preparedModel, nullptr);
1142             } else {
1143                 ASSERT_NE(preparedModel, nullptr);
1144                 EvaluatePreparedModel(preparedModel, testModelAdd,
1145                                       /*testDynamicOutputShape=*/false);
1146             }
1147         }
1148     }
1149 }
1150 
// Security test: after saving the ADD model's compilation, replace its model
// cache files with those of a different (MUL) model saved under a different
// token. Preparing from the mismatched cache must fail with GENERAL_FAILURE
// rather than silently producing a model from the wrong cache content.
TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache, using distinct "_mul" filenames.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // Save the modelAdd compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelAdd, modelCache, dataCache);
    }

    // Replace the model cache of modelAdd with modelMul.
    copyCacheFiles(modelCacheMul, mModelCache);

    // Retrieve the preparedModel from cache, expect failure.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
1200 
1201 static const auto kNamedDeviceChoices = testing::ValuesIn(getNamedDevices());
1202 static const auto kOperandTypeChoices =
1203         testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);
1204 
printCompilationCachingTest(const testing::TestParamInfo<CompilationCachingTestParam> & info)1205 std::string printCompilationCachingTest(
1206         const testing::TestParamInfo<CompilationCachingTestParam>& info) {
1207     const auto& [namedDevice, operandType] = info.param;
1208     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1209     return gtestCompliantName(getName(namedDevice) + "_" + type);
1210 }
1211 
1212 INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest,
1213                         testing::Combine(kNamedDeviceChoices, kOperandTypeChoices),
1214                         printCompilationCachingTest);
1215 
1216 using CompilationCachingSecurityTestParam = std::tuple<NamedDevice, OperandType, uint32_t>;
1217 
1218 class CompilationCachingSecurityTest
1219     : public CompilationCachingTestBase,
1220       public testing::WithParamInterface<CompilationCachingSecurityTestParam> {
1221   protected:
CompilationCachingSecurityTest()1222     CompilationCachingSecurityTest()
1223         : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
1224                                      std::get<OperandType>(GetParam())) {}
1225 
SetUp()1226     void SetUp() {
1227         CompilationCachingTestBase::SetUp();
1228         generator.seed(kSeed);
1229     }
1230 
1231     // Get a random integer within a closed range [lower, upper].
1232     template <typename T>
getRandomInt(T lower,T upper)1233     T getRandomInt(T lower, T upper) {
1234         std::uniform_int_distribution<T> dis(lower, upper);
1235         return dis(generator);
1236     }
1237 
1238     // Randomly flip one single bit of the cache entry.
flipOneBitOfCache(const std::string & filename,bool * skip)1239     void flipOneBitOfCache(const std::string& filename, bool* skip) {
1240         FILE* pFile = fopen(filename.c_str(), "r+");
1241         ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
1242         long int fileSize = ftell(pFile);
1243         if (fileSize == 0) {
1244             fclose(pFile);
1245             *skip = true;
1246             return;
1247         }
1248         ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
1249         int readByte = fgetc(pFile);
1250         ASSERT_NE(readByte, EOF);
1251         ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
1252         ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
1253         fclose(pFile);
1254         *skip = false;
1255     }
1256 
1257     // Randomly append bytes to the cache entry.
appendBytesToCache(const std::string & filename,bool * skip)1258     void appendBytesToCache(const std::string& filename, bool* skip) {
1259         FILE* pFile = fopen(filename.c_str(), "a");
1260         uint32_t appendLength = getRandomInt(1, 256);
1261         for (uint32_t i = 0; i < appendLength; i++) {
1262             ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
1263         }
1264         fclose(pFile);
1265         *skip = false;
1266     }
1267 
1268     enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
1269 
1270     // Test if the driver behaves as expected when given corrupted cache or token.
1271     // The modifier will be invoked after save to cache but before prepare from cache.
1272     // The modifier accepts one pointer argument "skip" as the returning value, indicating
1273     // whether the test should be skipped or not.
testCorruptedCache(ExpectedResult expected,std::function<void (bool *)> modifier)1274     void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
1275         const TestModel& testModel = createTestModel();
1276         const Model model = createModel(testModel);
1277         if (checkEarlyTermination(model)) return;
1278 
1279         // Save the compilation to cache.
1280         {
1281             hidl_vec<hidl_handle> modelCache, dataCache;
1282             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1283             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1284             saveModelToCache(model, modelCache, dataCache);
1285         }
1286 
1287         bool skip = false;
1288         modifier(&skip);
1289         if (skip) return;
1290 
1291         // Retrieve preparedModel from cache.
1292         {
1293             sp<IPreparedModel> preparedModel = nullptr;
1294             ErrorStatus status;
1295             hidl_vec<hidl_handle> modelCache, dataCache;
1296             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1297             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1298             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1299 
1300             switch (expected) {
1301                 case ExpectedResult::GENERAL_FAILURE:
1302                     ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1303                     ASSERT_EQ(preparedModel, nullptr);
1304                     break;
1305                 case ExpectedResult::NOT_CRASH:
1306                     ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
1307                     break;
1308                 default:
1309                     FAIL();
1310             }
1311         }
1312     }
1313 
    // RNG seed taken from the test parameter tuple, so each parameterized
    // instance corrupts the cache in a different (but reproducible) way.
    const uint32_t kSeed = std::get<uint32_t>(GetParam());
    // Mersenne Twister engine; presumably seeded with kSeed during SetUp — confirm against the
    // fixture's setup code (outside this view).
    std::mt19937 generator;
1316 };
1317 
// Flipping a single random bit in any model cache file must make the driver
// report GENERAL_FAILURE when preparing from cache.
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumModelCache; ++index) {
        const auto corrupt = [this, index](bool* skip) {
            flipOneBitOfCache(mModelCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corrupt);
    }
}
1325 
// Appending extra bytes to any model cache file must make the driver report
// GENERAL_FAILURE when preparing from cache.
TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumModelCache; ++index) {
        const auto corrupt = [this, index](bool* skip) {
            appendBytesToCache(mModelCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corrupt);
    }
}
1333 
// Flipping a single random bit in a data cache file may or may not be detected;
// the driver only has to stay consistent and not crash.
TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumDataCache; ++index) {
        const auto corrupt = [this, index](bool* skip) {
            flipOneBitOfCache(mDataCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corrupt);
    }
}
1341 
// Appending extra bytes to a data cache file may or may not be detected;
// the driver only has to stay consistent and not crash.
TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumDataCache; ++index) {
        const auto corrupt = [this, index](bool* skip) {
            appendBytesToCache(mDataCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corrupt);
    }
}
1349 
// Presenting a token that differs by one bit from the one used when saving
// must make the driver report GENERAL_FAILURE when preparing from cache.
TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) return;
    const auto corruptToken = [this](bool* skip) {
        // Pick a random byte of mToken and flip one random bit in it.
        const uint32_t byteIndex =
                getRandomInt(0u, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN) - 1);
        mToken[byteIndex] ^= (1U << getRandomInt(0, 7));
        *skip = false;
    };
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptToken);
}
1360 
printCompilationCachingSecurityTest(const testing::TestParamInfo<CompilationCachingSecurityTestParam> & info)1361 std::string printCompilationCachingSecurityTest(
1362         const testing::TestParamInfo<CompilationCachingSecurityTestParam>& info) {
1363     const auto& [namedDevice, operandType, seed] = info.param;
1364     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1365     return gtestCompliantName(getName(namedDevice) + "_" + type + "_" + std::to_string(seed));
1366 }
1367 
// Instantiate the security tests for every (device, operand type) combination,
// each with 10 different RNG seeds (0..9).
// NOTE(review): INSTANTIATE_TEST_CASE_P is deprecated in newer googletest in
// favor of INSTANTIATE_TEST_SUITE_P — confirm the gtest version in this tree
// before migrating.
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingSecurityTest,
                        testing::Combine(kNamedDeviceChoices, kOperandTypeChoices,
                                         testing::Range(0U, 10U)),
                        printCompilationCachingSecurityTest);
1372 
1373 }  // namespace android::hardware::neuralnetworks::V1_2::vts::functional
1374