1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <android-base/scopeguard.h>
18 #include <gtest/gtest.h>
19 
20 #include <cstdlib>
21 #include <filesystem>
22 #include <numeric>
23 #include <string>
24 #include <string_view>
25 #include <tuple>
26 #include <vector>
27 
28 #include "HalInterfaces.h"
29 #include "Manager.h"
30 #include "SampleDriver.h"
31 #include "TestNeuralNetworksWrapper.h"
32 
using namespace android::nn;
using namespace hal;
using Result = test_wrapper::Result;
using Type = test_wrapper::Type;

// Timing value reported on failed executions; all-ones marks "no timing available".
const Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};

// Shorthand for the synchronized fast-message-queue descriptor used by execution bursts.
template <typename T>
using MQDescriptorSync = ::android::hardware::MQDescriptorSync<T>;
40 
41 namespace android::hardware::neuralnetworks::V1_0 {
42 
operator <<(::std::ostream & os,ErrorStatus errorStatus)43 ::std::ostream& operator<<(::std::ostream& os, ErrorStatus errorStatus) {
44     return os << toString(errorStatus);
45 }
46 
47 }  // namespace android::hardware::neuralnetworks::V1_0
48 
49 namespace {
50 
51 enum class HasCalledPrepareModel { NO, WITHOUT_CACHING, WITH_CACHING };
52 
53 // Print HasCalledPrepareModel enum for better GTEST failure messages
operator <<(std::ostream & os,HasCalledPrepareModel hasCalledPrepareModel)54 std::ostream& operator<<(std::ostream& os, HasCalledPrepareModel hasCalledPrepareModel) {
55     switch (hasCalledPrepareModel) {
56         case HasCalledPrepareModel::NO:
57             return os << "NO";
58         case HasCalledPrepareModel::WITHOUT_CACHING:
59             return os << "WITHOUT_CACHING";
60         case HasCalledPrepareModel::WITH_CACHING:
61             return os << "WITH_CACHING";
62     }
63     CHECK(false) << "HasCalledPrepareModel print called with invalid code "
64                  << static_cast<int>(hasCalledPrepareModel);
65     return os;
66 }
67 
68 // Whether the driver is expected to be registered because it can pass initialization.
canDeviceBeRegistered(ErrorStatus error,uint32_t numModelCache,uint32_t numDataCache)69 bool canDeviceBeRegistered(ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
70     constexpr uint32_t maxNumCacheFiles =
71             static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES);
72     return error == ErrorStatus::NONE && numModelCache <= maxNumCacheFiles &&
73            numDataCache <= maxNumCacheFiles;
74 }
75 
// Whether the driver supports caching based on the returns from getNumberOfCacheFilesNeeded:
// caching is unsupported only when both reported file counts are zero.
bool isCachingSupported(uint32_t numModelCache, uint32_t numDataCache) {
    const bool noCacheFilesNeeded = (numModelCache == 0 && numDataCache == 0);
    return !noCacheFilesNeeded;
}
80 
// This is an IDevice for testing purposes which overrides several methods from sample driver:
// - supports all the operations and is faster than cpu fallback.
// - overrides getNumberOfCacheFilesNeeded to report according to given parameters.
// - overrides prepareModelFromCache_1_3 to return error status according to
//   mErrorStatusPrepareFromCache.
// - produces CachingPreparedModel on prepareModel and prepareModelFromCache_1_3.
//
// The cache entry is written by prepareModel_1_3 and is checked later by
// CachingDriver::prepareModelFromCache_1_3.
//
// The CachingDriver has 2 flags mHasCalledPrepareModelFromCache and mHasCalledPrepareModel
// to check if the correct methods are invoked by the runtime.
class CachingDriver : public sample_driver::SampleDriver {
   private:
    // Number of bytes written into each cache file and expected back on read.
    static constexpr size_t kCacheSize = 256;

    // Minimal prepared model: every execution entry point reports DEVICE_UNAVAILABLE,
    // since these tests exercise compilation/caching only, never execution.
    class CachingPreparedModel : public IPreparedModel {
       public:
        CachingPreparedModel() = default;

        Return<V1_0::ErrorStatus> execute(const V1_0::Request&,
                                          const sp<V1_0::IExecutionCallback>&) override {
            return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
        }
        Return<V1_0::ErrorStatus> execute_1_2(const V1_0::Request&, MeasureTiming,
                                              const sp<V1_2::IExecutionCallback>&) override {
            return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
        }
        Return<V1_3::ErrorStatus> execute_1_3(const V1_3::Request&, MeasureTiming,
                                              const OptionalTimePoint&,
                                              const OptionalTimeoutDuration&,
                                              const sp<V1_3::IExecutionCallback>&) override {
            return V1_3::ErrorStatus::DEVICE_UNAVAILABLE;
        }
        Return<void> executeSynchronously(const V1_0::Request&, MeasureTiming,
                                          executeSynchronously_cb cb) override {
            cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
            return Void();
        }
        Return<void> executeSynchronously_1_3(const V1_3::Request&, MeasureTiming,
                                              const OptionalTimePoint&,
                                              const OptionalTimeoutDuration&,
                                              executeSynchronously_1_3_cb cb) override {
            cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
            return Void();
        }
        Return<void> configureExecutionBurst(const sp<V1_2::IBurstCallback>&,
                                             const MQDescriptorSync<V1_2::FmqRequestDatum>&,
                                             const MQDescriptorSync<V1_2::FmqResultDatum>&,
                                             configureExecutionBurst_cb cb) override {
            cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, nullptr);
            return Void();
        }
        // NOTE(review): not marked 'override' — confirm the signature matches
        // IPreparedModel::executeFenced so this actually overrides the base method.
        Return<void> executeFenced(const hal::Request&, const hidl_vec<hidl_handle>&, MeasureTiming,
                                   const OptionalTimePoint&, const OptionalTimeoutDuration&,
                                   const OptionalTimeoutDuration&, executeFenced_cb cb) {
            cb(ErrorStatus::DEVICE_UNAVAILABLE, hidl_handle(nullptr), nullptr);
            return Void();
        }
    };

   public:
    // name: service name registered with the runtime.
    // errorStatusGetNumCacheFiles: status reported by getNumberOfCacheFilesNeeded.
    // numModelCache / numDataCache: cache-file counts reported by getNumberOfCacheFilesNeeded.
    // errorStatusPrepareFromCache: status notified by prepareModelFromCache_1_3.
    CachingDriver(std::string_view name, ErrorStatus errorStatusGetNumCacheFiles,
                  uint32_t numModelCache, uint32_t numDataCache,
                  ErrorStatus errorStatusPrepareFromCache)
        : SampleDriver(name.data()),
          mErrorStatusGetNumCacheFiles(errorStatusGetNumCacheFiles),
          mNumModelCache(numModelCache),
          mNumDataCache(numDataCache),
          mErrorStatusPrepareFromCache(errorStatusPrepareFromCache) {
        // Fill the two payloads with distinct deterministic patterns (0,1,2,... vs
        // 1,2,3,...) so a model/data cache mix-up is detectable on read-back.
        mModelCacheData.resize(kCacheSize);
        std::iota(mModelCacheData.begin(), mModelCacheData.end(), 0);
        mDataCacheData.resize(kCacheSize);
        std::iota(mDataCacheData.begin(), mDataCacheData.end(), 1);
    }
    ~CachingDriver() override {}

    // Reports faster than cpu.
    Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
        android::nn::initVLogMask();
        const PerformanceInfo kPerf = {.execTime = 0.1, .powerUsage = 0.1};
        Capabilities capabilities = {
                .relaxedFloat32toFloat16PerformanceScalar = kPerf,
                .relaxedFloat32toFloat16PerformanceTensor = kPerf,
                .operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_3>(kPerf),
                .ifPerformance = kPerf,
                .whilePerformance = kPerf};
        cb(V1_3::ErrorStatus::NONE, capabilities);
        return Void();
    }

    // Reports supporting all operations.
    Return<void> getSupportedOperations_1_3(const Model& model,
                                            getSupportedOperations_1_3_cb cb) override {
        std::vector<bool> supported(model.main.operations.size(), true);
        cb(V1_3::ErrorStatus::NONE, supported);
        return Void();
    }

    // Reports according to the constructor parameters.
    Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override {
        cb(convertToV1_0(mErrorStatusGetNumCacheFiles), mNumModelCache, mNumDataCache);
        return Void();
    }

    // Generates CachingPreparedModel.
    // Writes the cache entry per mCacheXData and sets mHasCalledPrepareModel.
    Return<V1_3::ErrorStatus> prepareModel_1_3(
            const Model&, ExecutionPreference, Priority, const OptionalTimePoint&,
            const hidl_vec<hidl_handle>& modelCacheHandle,
            const hidl_vec<hidl_handle>& dataCacheHandle, const CacheToken&,
            const sp<V1_3::IPreparedModelCallback>& cb) override {
        checkNumberOfCacheHandles(modelCacheHandle.size(), dataCacheHandle.size());
        // Any cache handle means the runtime requested caching for this compilation.
        if (modelCacheHandle.size() != 0 || dataCacheHandle.size() != 0) {
            writeToCache(modelCacheHandle, mModelCacheData);
            writeToCache(dataCacheHandle, mDataCacheData);
            mHasCalledPrepareModel = HasCalledPrepareModel::WITH_CACHING;
        } else {
            mHasCalledPrepareModel = HasCalledPrepareModel::WITHOUT_CACHING;
        }
        cb->notify_1_3(V1_3::ErrorStatus::NONE, new CachingPreparedModel());
        return V1_3::ErrorStatus::NONE;
    }

    // Checks if the cache entry is correct, notifies error status according to
    // mErrorStatusPrepareFromCache, sets mHasCalledPrepareModelFromCache.
    Return<V1_3::ErrorStatus> prepareModelFromCache_1_3(
            const OptionalTimePoint&, const hidl_vec<hidl_handle>& modelCacheHandle,
            const hidl_vec<hidl_handle>& dataCacheHandle, const CacheToken&,
            const sp<V1_3::IPreparedModelCallback>& callback) override {
        readFromCache(modelCacheHandle, mModelCacheData);
        readFromCache(dataCacheHandle, mDataCacheData);
        mHasCalledPrepareModelFromCache = true;
        // Only hand back a prepared model when simulating a successful cache hit.
        if (mErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
            callback->notify_1_3(mErrorStatusPrepareFromCache, new CachingPreparedModel());
        } else {
            callback->notify_1_3(mErrorStatusPrepareFromCache, nullptr);
        }
        return V1_3::ErrorStatus::NONE;
    };

    // True once the runtime has invoked prepareModelFromCache_1_3 on this driver.
    bool hasCalledPrepareModelFromCache() const { return mHasCalledPrepareModelFromCache; }
    // How (if at all) the runtime invoked prepareModel_1_3 on this driver.
    HasCalledPrepareModel hasCalledPrepareModel() const { return mHasCalledPrepareModel; }

   private:
    // Checks the number of cache files passed to the driver from runtime: when caching
    // is supported, either no handles or exactly the advertised counts; otherwise none.
    void checkNumberOfCacheHandles(size_t modelCache, size_t dataCache) {
        if (isCachingSupported(mNumModelCache, mNumDataCache)) {
            if (modelCache != 0 || dataCache != 0) {
                ASSERT_EQ(modelCache, mNumModelCache);
                ASSERT_EQ(dataCache, mNumDataCache);
            }
        } else {
            ASSERT_EQ(modelCache, 0ul);
            ASSERT_EQ(dataCache, 0ul);
        }
    }

    // Writes kCacheSize bytes of `cache` into each handle; each must wrap exactly one fd.
    void writeToCache(const hidl_vec<hidl_handle>& handles, const std::vector<uint8_t>& cache) {
        for (uint32_t i = 0; i < handles.size(); ++i) {
            ASSERT_EQ(handles[i]->numFds, 1);
            EXPECT_EQ(write(handles[i]->data[0], cache.data(), kCacheSize),
                      static_cast<ssize_t>(kCacheSize));
        }
    }

    // Reads kCacheSize bytes from each handle and checks the content against `expected`.
    void readFromCache(const hidl_vec<hidl_handle>& handles, const std::vector<uint8_t>& expected) {
        for (uint32_t i = 0; i < handles.size(); ++i) {
            ASSERT_EQ(handles[i]->numFds, 1);
            std::vector<uint8_t> actual(kCacheSize);
            EXPECT_EQ(read(handles[i]->data[0], actual.data(), kCacheSize),
                      static_cast<ssize_t>(kCacheSize));
            EXPECT_EQ(actual, expected);
        }
    }

    // Deterministic payloads written to / expected from the cache files.
    std::vector<uint8_t> mModelCacheData;
    std::vector<uint8_t> mDataCacheData;

    // Behavior knobs fixed at construction; see constructor comment for meanings.
    const ErrorStatus mErrorStatusGetNumCacheFiles;
    const uint32_t mNumModelCache;
    const uint32_t mNumDataCache;
    const ErrorStatus mErrorStatusPrepareFromCache;

    // Set when the runtime calls the corresponding prepare entry points.
    bool mHasCalledPrepareModelFromCache = false;
    HasCalledPrepareModel mHasCalledPrepareModel = HasCalledPrepareModel::NO;
};
268 
CreateBroadcastAddModel(test_wrapper::Model * model)269 void CreateBroadcastAddModel(test_wrapper::Model* model) {
270     test_wrapper::OperandType matrixType(Type::TENSOR_FLOAT32, {2, 2});
271     test_wrapper::OperandType vectorType(Type::TENSOR_FLOAT32, {2});
272     test_wrapper::OperandType scalarType(Type::INT32, {});
273     int32_t activation(ANEURALNETWORKS_FUSED_NONE);
274     auto a = model->addOperand(&matrixType);
275     auto b = model->addOperand(&vectorType);
276     auto c = model->addOperand(&matrixType);
277     auto d = model->addOperand(&scalarType);
278     model->setOperandValue(d, &activation, sizeof(activation));
279     model->addOperation(ANEURALNETWORKS_ADD, {a, b, d}, {c});
280     model->identifyInputsAndOutputs({a, b}, {c});
281     ASSERT_TRUE(model->isValid());
282     ASSERT_EQ(model->finish(), Result::NO_ERROR);
283 }
284 
getDeviceWithName(std::string_view deviceName,const ANeuralNetworksDevice ** outputDevice)285 void getDeviceWithName(std::string_view deviceName, const ANeuralNetworksDevice** outputDevice) {
286     uint32_t numDevices = 0;
287     ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
288     EXPECT_GE(numDevices, (uint32_t)1);
289 
290     int numMatchingDevices = 0;
291     for (uint32_t i = 0; i < numDevices; i++) {
292         ANeuralNetworksDevice* device = nullptr;
293         ASSERT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
294 
295         const char* buffer = nullptr;
296         ASSERT_EQ(ANeuralNetworksDevice_getName(device, &buffer), ANEURALNETWORKS_NO_ERROR);
297         if (deviceName == buffer) {
298             *outputDevice = device;
299             numMatchingDevices++;
300         }
301     }
302 
303     EXPECT_LE(numMatchingDevices, 1);
304 }
305 
// Test device registration with a driver parameterized with
// - ErrorStatus returning from getNumberOfCacheFilesNeeded
// - Number of model cache files returning from getNumberOfCacheFilesNeeded
// - Number of data cache files returning from getNumberOfCacheFilesNeeded
using DeviceRegistrationTestParam = std::tuple<ErrorStatus, uint32_t, uint32_t>;

class DeviceRegistrationTest : public ::testing::TestWithParam<DeviceRegistrationTestParam> {
   protected:
    static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
    // Per-test parameters, unpacked from DeviceRegistrationTestParam above.
    const ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam());
    const uint32_t kNumModelCache = std::get<1>(GetParam());
    const uint32_t kNumDataCache = std::get<2>(GetParam());
    // Driver under test, configured from the parameters; prepareModelFromCache_1_3
    // is configured to succeed (ErrorStatus::NONE).
    const sp<CachingDriver> kDriver =
            new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
                              kNumDataCache, ErrorStatus::NONE);
};
322 
// Registers the parameterized driver and verifies that it appears in the device
// list exactly when canDeviceBeRegistered() says it should.
TEST_P(DeviceRegistrationTest, CachingFailure) {
    // CPU-only mode skips driver devices entirely; nothing to verify.
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }

    DeviceManager::get()->forTest_registerDevice(kDeviceName.data(), kDriver);
    const auto cleanup = android::base::make_scope_guard(
            [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });

    // Look the device up by name; a null result means registration was rejected.
    const ANeuralNetworksDevice* registeredDevice = nullptr;
    getDeviceWithName(kDeviceName, &registeredDevice);

    const bool expectDeviceToBeRegistered =
            canDeviceBeRegistered(kErrorStatusGetNumCacheFiles, kNumModelCache, kNumDataCache);
    ASSERT_EQ(registeredDevice != nullptr, expectDeviceToBeRegistered);
}
342 
343 // Test model compilation with a driver parameterized with
344 // - Number of model cache files returning from getNumberOfCacheFilesNeeded
345 // - Number of data cache files returning from getNumberOfCacheFilesNeeded
346 // - ErrorStatus returning from prepareModelFromCache_1_3
347 using CompilationCachingTestParam = std::tuple<uint32_t, uint32_t, ErrorStatus>;
348 
349 class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachingTestParam> {
350    protected:
SetUp()351     virtual void SetUp() override {
352         char cacheDirTemp[] =
353                 "/data/local/tmp/AVeryLongDirectoryNameForTestCompilationCachingXXXXXX";
354         char* cacheDir = mkdtemp(cacheDirTemp);
355         ASSERT_NE(cacheDir, nullptr);
356         mCacheDir = cacheDir;
357         CreateBroadcastAddModel(&mModel);
358     }
359 
TearDown()360     virtual void TearDown() override {
361         if (!::testing::Test::HasFailure()) {
362             std::filesystem::remove_all(mCacheDir);
363         }
364     }
365 
compileModel(const sp<CachingDriver> & driver,bool withToken)366     void compileModel(const sp<CachingDriver>& driver, bool withToken) {
367         DeviceManager::get()->forTest_registerDevice(kDeviceName.data(), driver);
368         const auto cleanup = android::base::make_scope_guard(
369                 [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });
370 
371         // Get a handle to the single driver device matching kDeviceName.
372         const ANeuralNetworksDevice* device = nullptr;
373         getDeviceWithName(kDeviceName, &device);
374         ASSERT_NE(device, nullptr);
375 
376         // Compile the model with the device.
377         ANeuralNetworksCompilation* compilation = nullptr;
378         ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), &device, 1,
379                                                               &compilation),
380                   ANEURALNETWORKS_NO_ERROR);
381         if (withToken) {
382             ASSERT_EQ(ANeuralNetworksCompilation_setCaching(compilation, mCacheDir.c_str(),
383                                                             kToken.data()),
384                       ANEURALNETWORKS_NO_ERROR);
385         }
386         ASSERT_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR);
387 
388         // close memory
389         ANeuralNetworksCompilation_free(compilation);
390     }
391 
createCache()392     void createCache() {
393         sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
394                                                      kNumDataCache, ErrorStatus::NONE);
395         compileModel(driver, /*withToken=*/true);
396     }
397 
398     static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
399     const uint32_t kNumModelCache = std::get<0>(GetParam());
400     const uint32_t kNumDataCache = std::get<1>(GetParam());
401     const ErrorStatus kErrorStatusPrepareFromCache = std::get<2>(GetParam());
402     const bool kIsCachingSupported = isCachingSupported(kNumModelCache, kNumDataCache);
403     test_wrapper::Model mModel;
404     std::string mCacheDir;
405     const CacheToken kToken{};
406 };
407 
// With a token but no pre-existing cache files, the runtime must compile from
// scratch, requesting caching only when the driver supports it.
TEST_P(CompilationCachingTest, TokenProvidedAndCacheNotExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
                                                 kNumDataCache, kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/true);

    // When cache file does not exist, the runtime should never call prepareModelFromCache_1_3.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());

    // The runtime should call prepareModel_1_3, requesting caching iff caching is supported.
    const HasCalledPrepareModel expected = kIsCachingSupported
                                                   ? HasCalledPrepareModel::WITH_CACHING
                                                   : HasCalledPrepareModel::WITHOUT_CACHING;
    EXPECT_EQ(driver->hasCalledPrepareModel(), expected);
}
424 
// With a token and pre-existing cache files, the runtime should try the cache first
// (when caching is supported) and fall back to a full compile on cache failure.
TEST_P(CompilationCachingTest, TokenProvidedAndCacheExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    createCache();
    sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
                                                 kNumDataCache, kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/true);

    // With cache files present, prepareModelFromCache_1_3 is called iff caching is supported.
    EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), kIsCachingSupported);

    // Expected prepareModel_1_3 usage:
    // - caching unsupported -> called without cache handles;
    // - cache load succeeds -> not called at all;
    // - cache load fails    -> called again with cache handles.
    const HasCalledPrepareModel expectHasCalledPrepareModel = [this] {
        if (!kIsCachingSupported) {
            return HasCalledPrepareModel::WITHOUT_CACHING;
        }
        return kErrorStatusPrepareFromCache == ErrorStatus::NONE
                       ? HasCalledPrepareModel::NO
                       : HasCalledPrepareModel::WITH_CACHING;
    }();
    EXPECT_EQ(driver->hasCalledPrepareModel(), expectHasCalledPrepareModel);
}
455 
// When compiling without a caching token, the runtime must neither read the cache
// nor ask the driver to populate it.
TEST_P(CompilationCachingTest, TokenNotProvided) {
    // CPU-only mode skips driver devices entirely; nothing to verify.
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> driver = new CachingDriver(kDeviceName, ErrorStatus::NONE, kNumModelCache,
                                                 kNumDataCache, kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/false);

    // When no NDK token is provided by the client, the runtime should never call
    // prepareModelFromCache_1_3 or request caching with prepareModel_1_3.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());
    EXPECT_EQ(driver->hasCalledPrepareModel(), HasCalledPrepareModel::WITHOUT_CACHING);
}
469 
// Error statuses exercised for getNumberOfCacheFilesNeeded during registration.
static const auto kErrorStatusGetNumCacheFilesChoices =
        testing::Values(ErrorStatus::NONE, ErrorStatus::DEVICE_UNAVAILABLE);
// Cache-file counts including one past the allowed maximum (invalid for registration).
static const auto kNumCacheChoices =
        testing::Values(0ul, 1ul, static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES),
                        static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES) + 1);
// Cache-file counts that are all valid (compilation tests need a registered device).
static const auto kNumValidCacheChoices =
        testing::Values(0ul, 1ul, static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES));
// Error statuses exercised for prepareModelFromCache_1_3.
static const auto kErrorStatusPrepareFromCacheChoices =
        testing::Values(ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE,
                        ErrorStatus::DEVICE_UNAVAILABLE, ErrorStatus::INVALID_ARGUMENT);

// NOTE(review): INSTANTIATE_TEST_CASE_P is deprecated in newer GoogleTest in favor of
// INSTANTIATE_TEST_SUITE_P — migrate once the tree's GTest version supports it.
INSTANTIATE_TEST_CASE_P(TestCompilationCaching, DeviceRegistrationTest,
                        testing::Combine(kErrorStatusGetNumCacheFilesChoices, kNumCacheChoices,
                                         kNumCacheChoices));

INSTANTIATE_TEST_CASE_P(TestCompilationCaching, CompilationCachingTest,
                        testing::Combine(kNumValidCacheChoices, kNumValidCacheChoices,
                                         kErrorStatusPrepareFromCacheChoices));
488 
489 }  // namespace
490