1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.crashtest.app;
18 
import android.content.Context;
import android.util.Log;

import androidx.test.InstrumentationRegistry;

import com.android.nn.benchmark.core.BenchmarkException;
import com.android.nn.benchmark.core.BenchmarkResult;
import com.android.nn.benchmark.core.NNTestBase;
import com.android.nn.benchmark.core.NnApiDelegationFailure;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
39 
40 public interface AcceleratorSpecificTestSupport {
41     String TAG = "AcceleratorTest";
42 
findTestModelRunningOnAccelerator( Context context, String acceleratorName)43     static Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator(
44             Context context, String acceleratorName) throws NnApiDelegationFailure {
45         for (TestModels.TestModelEntry model : TestModels.modelsList()) {
46             if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) {
47                 return Optional.of(model);
48             }
49         }
50         return Optional.empty();
51     }
52 
findAllTestModelsRunningOnAccelerator( Context context, String acceleratorName)53     static List<TestModels.TestModelEntry> findAllTestModelsRunningOnAccelerator(
54             Context context, String acceleratorName) throws NnApiDelegationFailure {
55         List<TestModels.TestModelEntry> result = new ArrayList<>();
56         for (TestModels.TestModelEntry model : TestModels.modelsList()) {
57             if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) {
58                 result.add(model);
59             }
60         }
61         return result;
62     }
63 
ramdomInRange(long min, long max)64     default long ramdomInRange(long min, long max) {
65         return min + (long) (Math.random() * (max - min));
66     }
67 
getTestParameter(String key, String defaultValue)68     static String getTestParameter(String key, String defaultValue) {
69         return InstrumentationRegistry.getArguments().getString(key, defaultValue);
70     }
71 
getBooleanTestParameter(String key, boolean defaultValue)72     static boolean getBooleanTestParameter(String key, boolean defaultValue) {
73         // All instrumentation arguments are passed as String so I have to convert the value here.
74         return Boolean.parseBoolean(
75                 InstrumentationRegistry.getArguments().getString(key, "" + defaultValue));
76     }
77 
78     static final String ACCELERATOR_FILTER_PROPERTY = "nnCrashtestDeviceFilter";
79     static final String INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY =
80             "nnCrashtestIncludeNnapiReference";
81 
getTargetAcceleratorNames()82     static List<String> getTargetAcceleratorNames() {
83         List<String> accelerators = new ArrayList<>();
84         String acceleratorFilter = getTestParameter(ACCELERATOR_FILTER_PROPERTY, ".+");
85         accelerators.addAll(NNTestBase.availableAcceleratorNames().stream().filter(
86                 name -> name.matches(acceleratorFilter)).collect(
87                 Collectors.toList()));
88         if (getBooleanTestParameter(INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY, false)) {
89             accelerators.add(null); // running tests with no specified target accelerator too
90         }
91         return accelerators;
92     }
93 
94 
perAcceleratorTestConfig(List<Object[]> testConfig)95     static List<Object[]> perAcceleratorTestConfig(List<Object[]> testConfig) {
96         return testConfig.stream()
97                 .flatMap(currConfigurationParams -> getTargetAcceleratorNames().stream().map(
98                         accelerator -> {
99                             Object[] result =
100                                     Arrays.copyOf(currConfigurationParams,
101                                             currConfigurationParams.length + 1);
102                             result[currConfigurationParams.length] = accelerator;
103                             return result;
104                         }))
105                 .collect(Collectors.toList());
106     }
107 
108     class DriverLivenessChecker implements Callable<Boolean> {
109         final Processor mProcessor;
110         private final AtomicBoolean mRun = new AtomicBoolean(true);
111         private final TestModels.TestModelEntry mTestModelEntry;
112 
DriverLivenessChecker(Context context, String acceleratorName, TestModels.TestModelEntry testModelEntry)113         public DriverLivenessChecker(Context context, String acceleratorName,
114                 TestModels.TestModelEntry testModelEntry) {
115             mProcessor = new Processor(context,
116                     new Processor.Callback() {
117                         @Override
118                         public void onBenchmarkFinish(boolean ok) {
119                         }
120 
121                         @Override
122                         public void onStatusUpdate(int testNumber, int numTests, String modelName) {
123                         }
124                     }, new int[0]);
125             mProcessor.setUseNNApi(true);
126             mProcessor.setCompleteInputSet(false);
127             mProcessor.setNnApiAcceleratorName(acceleratorName);
128             mTestModelEntry = testModelEntry;
129         }
130 
stop()131         public void stop() {
132             mRun.set(false);
133         }
134 
135         @Override
call()136         public Boolean call() throws Exception {
137             while (mRun.get()) {
138                 try {
139                     BenchmarkResult modelExecutionResult = mProcessor.getInstrumentationResult(
140                             mTestModelEntry, 0, 3);
141                     if (modelExecutionResult.hasBenchmarkError()) {
142                         Log.e(TAG, String.format("Benchmark failed with message %s",
143                                 modelExecutionResult.getBenchmarkError()));
144                         return false;
145                     }
146                 } catch (IOException | BenchmarkException e) {
147                     Log.e(TAG, String.format("Error running model %s", mTestModelEntry.mModelName));
148                     return false;
149                 }
150             }
151 
152             return true;
153         }
154     }
155 }