1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.crashtest.app; 18 19 import android.content.Context; 20 import android.util.Log; 21 22 import androidx.test.InstrumentationRegistry; 23 24 import com.android.nn.benchmark.core.BenchmarkException; 25 import com.android.nn.benchmark.core.BenchmarkResult; 26 import com.android.nn.benchmark.core.NNTestBase; 27 import com.android.nn.benchmark.core.NnApiDelegationFailure; 28 import com.android.nn.benchmark.core.Processor; 29 import com.android.nn.benchmark.core.TestModels; 30 31 import java.io.IOException; 32 import java.util.ArrayList; 33 import java.util.Arrays; 34 import java.util.List; 35 import java.util.Optional; 36 import java.util.concurrent.Callable; 37 import java.util.concurrent.atomic.AtomicBoolean; 38 import java.util.stream.Collectors; 39 40 public interface AcceleratorSpecificTestSupport { 41 String TAG = "AcceleratorTest"; 42 findTestModelRunningOnAccelerator( Context context, String acceleratorName)43 static Optional<TestModels.TestModelEntry> findTestModelRunningOnAccelerator( 44 Context context, String acceleratorName) throws NnApiDelegationFailure { 45 for (TestModels.TestModelEntry model : TestModels.modelsList()) { 46 if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) { 47 return Optional.of(model); 48 } 49 } 50 return Optional.empty(); 51 } 52 findAllTestModelsRunningOnAccelerator( Context context, String acceleratorName)53 static List<TestModels.TestModelEntry> findAllTestModelsRunningOnAccelerator( 54 Context context, String acceleratorName) throws NnApiDelegationFailure { 55 List<TestModels.TestModelEntry> result = new ArrayList<>(); 56 for (TestModels.TestModelEntry model : TestModels.modelsList()) { 57 if (Processor.isTestModelSupportedByAccelerator(context, model, acceleratorName)) { 58 result.add(model); 59 } 60 } 61 return result; 62 } 63 ramdomInRange(long min, long max)64 default long ramdomInRange(long min, long max) { 65 return min + (long) (Math.random() * (max - min)); 66 } 67 getTestParameter(String key, String defaultValue)68 static String getTestParameter(String key, String defaultValue) { 69 return InstrumentationRegistry.getArguments().getString(key, defaultValue); 70 } 71 getBooleanTestParameter(String key, boolean defaultValue)72 static boolean getBooleanTestParameter(String key, boolean defaultValue) { 73 // All instrumentation arguments are passed as String so I have to convert the value here. 74 return Boolean.parseBoolean( 75 InstrumentationRegistry.getArguments().getString(key, "" + defaultValue)); 76 } 77 78 static final String ACCELERATOR_FILTER_PROPERTY = "nnCrashtestDeviceFilter"; 79 static final String INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY = 80 "nnCrashtestIncludeNnapiReference"; 81 getTargetAcceleratorNames()82 static List<String> getTargetAcceleratorNames() { 83 List<String> accelerators = new ArrayList<>(); 84 String acceleratorFilter = getTestParameter(ACCELERATOR_FILTER_PROPERTY, ".+"); 85 accelerators.addAll(NNTestBase.availableAcceleratorNames().stream().filter( 86 name -> name.matches(acceleratorFilter)).collect( 87 Collectors.toList())); 88 if (getBooleanTestParameter(INCLUDE_NNAPI_SELECTED_ACCELERATOR_PROPERTY, false)) { 89 accelerators.add(null); // running tests with no specified target accelerator too 90 } 91 return accelerators; 92 } 93 94 perAcceleratorTestConfig(List<Object[]> testConfig)95 static List<Object[]> perAcceleratorTestConfig(List<Object[]> testConfig) { 96 return testConfig.stream() 97 .flatMap(currConfigurationParams -> getTargetAcceleratorNames().stream().map( 98 accelerator -> { 99 Object[] result = 100 Arrays.copyOf(currConfigurationParams, 101 currConfigurationParams.length + 1); 102 result[currConfigurationParams.length] = accelerator; 103 return result; 104 })) 105 .collect(Collectors.toList()); 106 } 107 108 class DriverLivenessChecker implements Callable<Boolean> { 109 final Processor mProcessor; 110 private final AtomicBoolean mRun = new AtomicBoolean(true); 111 private final TestModels.TestModelEntry mTestModelEntry; 112 DriverLivenessChecker(Context context, String acceleratorName, TestModels.TestModelEntry testModelEntry)113 public DriverLivenessChecker(Context context, String acceleratorName, 114 TestModels.TestModelEntry testModelEntry) { 115 mProcessor = new Processor(context, 116 new Processor.Callback() { 117 @Override 118 public void onBenchmarkFinish(boolean ok) { 119 } 120 121 @Override 122 public void onStatusUpdate(int testNumber, int numTests, String modelName) { 123 } 124 }, new int[0]); 125 mProcessor.setUseNNApi(true); 126 mProcessor.setCompleteInputSet(false); 127 mProcessor.setNnApiAcceleratorName(acceleratorName); 128 mTestModelEntry = testModelEntry; 129 } 130 stop()131 public void stop() { 132 mRun.set(false); 133 } 134 135 @Override call()136 public Boolean call() throws Exception { 137 while (mRun.get()) { 138 try { 139 BenchmarkResult modelExecutionResult = mProcessor.getInstrumentationResult( 140 mTestModelEntry, 0, 3); 141 if (modelExecutionResult.hasBenchmarkError()) { 142 Log.e(TAG, String.format("Benchmark failed with message %s", 143 modelExecutionResult.getBenchmarkError())); 144 return false; 145 } 146 } catch (IOException | BenchmarkException e) { 147 Log.e(TAG, String.format("Error running model %s", mTestModelEntry.mModelName)); 148 return false; 149 } 150 } 151 152 return true; 153 } 154 } 155 }