1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import static java.util.concurrent.TimeUnit.MILLISECONDS; 20 21 import android.content.Context; 22 import android.os.Trace; 23 import android.util.Log; 24 import android.util.Pair; 25 26 import java.io.IOException; 27 import java.util.Collections; 28 import java.util.List; 29 import java.util.concurrent.CountDownLatch; 30 import java.util.concurrent.atomic.AtomicBoolean; 31 32 /** Processor is a helper thread for running the work without blocking the UI thread. */ 33 public class Processor implements Runnable { 34 35 36 public interface Callback { onBenchmarkFinish(boolean ok)37 void onBenchmarkFinish(boolean ok); 38 onStatusUpdate(int testNumber, int numTests, String modelName)39 void onStatusUpdate(int testNumber, int numTests, String modelName); 40 } 41 42 protected static final String TAG = "NN_BENCHMARK"; 43 private Context mContext; 44 45 private final AtomicBoolean mRun = new AtomicBoolean(true); 46 47 volatile boolean mHasBeenStarted = false; 48 // You cannot restart a thread, so the completion flag is final 49 private final CountDownLatch mCompleted = new CountDownLatch(1); 50 private NNTestBase mTest; 51 private int mTestList[]; 52 private BenchmarkResult mTestResults[]; 53 54 private Processor.Callback mCallback; 55 56 private boolean mUseNNApi; 57 private boolean mMmapModel; 58 private boolean mCompleteInputSet; 59 private boolean mToggleLong; 60 private boolean mTogglePause; 61 private String mAcceleratorName; 62 private boolean mIgnoreUnsupportedModels; 63 private boolean mRunModelCompilationOnly; 64 // Max number of benchmark iterations to do in run method. 65 // Less or equal to 0 means unlimited 66 private int mMaxRunIterations; 67 68 private boolean mBenchmarkCompilationCaching; 69 private float mCompilationBenchmarkWarmupTimeSeconds; 70 private float mCompilationBenchmarkRunTimeSeconds; 71 private int mCompilationBenchmarkMaxIterations; 72 Processor(Context context, Processor.Callback callback, int[] testList)73 public Processor(Context context, Processor.Callback callback, int[] testList) { 74 mContext = context; 75 mCallback = callback; 76 mTestList = testList; 77 if (mTestList != null) { 78 mTestResults = new BenchmarkResult[mTestList.length]; 79 } 80 mAcceleratorName = null; 81 mIgnoreUnsupportedModels = false; 82 mRunModelCompilationOnly = false; 83 mMaxRunIterations = 0; 84 mBenchmarkCompilationCaching = false; 85 } 86 setUseNNApi(boolean useNNApi)87 public void setUseNNApi(boolean useNNApi) { 88 mUseNNApi = useNNApi; 89 } 90 setCompleteInputSet(boolean completeInputSet)91 public void setCompleteInputSet(boolean completeInputSet) { 92 mCompleteInputSet = completeInputSet; 93 } 94 setToggleLong(boolean toggleLong)95 public void setToggleLong(boolean toggleLong) { 96 mToggleLong = toggleLong; 97 } 98 setTogglePause(boolean togglePause)99 public void setTogglePause(boolean togglePause) { 100 mTogglePause = togglePause; 101 } 102 setNnApiAcceleratorName(String acceleratorName)103 public void setNnApiAcceleratorName(String acceleratorName) { 104 mAcceleratorName = acceleratorName; 105 } 106 setIgnoreUnsupportedModels(boolean value)107 public void setIgnoreUnsupportedModels(boolean value) { 108 mIgnoreUnsupportedModels = value; 109 } 110 setRunModelCompilationOnly(boolean value)111 public void setRunModelCompilationOnly(boolean value) { 112 mRunModelCompilationOnly = value; 113 } 114 setMmapModel(boolean value)115 public void setMmapModel(boolean value) { 116 mMmapModel = value; 117 } 118 setMaxRunIterations(int value)119 public void setMaxRunIterations(int value) { 120 mMaxRunIterations = value; 121 } 122 enableCompilationCachingBenchmarks( float warmupTimeSeconds, float runTimeSeconds, int maxIterations)123 public void enableCompilationCachingBenchmarks( 124 float warmupTimeSeconds, float runTimeSeconds, int maxIterations) { 125 mBenchmarkCompilationCaching = true; 126 mCompilationBenchmarkWarmupTimeSeconds = warmupTimeSeconds; 127 mCompilationBenchmarkRunTimeSeconds = runTimeSeconds; 128 mCompilationBenchmarkMaxIterations = maxIterations; 129 } 130 131 // Method to retrieve benchmark results for instrumentation tests. 132 // Returns null if the processor is configured to run compilation only getInstrumentationResult( TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)133 public BenchmarkResult getInstrumentationResult( 134 TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds) 135 throws IOException, BenchmarkException { 136 mTest = changeTest(mTest, t); 137 try { 138 BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(warmupTimeSeconds, 139 runTimeSeconds); 140 return result; 141 } finally { 142 mTest.destroy(); 143 mTest = null; 144 } 145 } 146 isTestModelSupportedByAccelerator(Context context, TestModels.TestModelEntry testModelEntry, String acceleratorName)147 public static boolean isTestModelSupportedByAccelerator(Context context, 148 TestModels.TestModelEntry testModelEntry, String acceleratorName) 149 throws NnApiDelegationFailure { 150 try (NNTestBase tb = testModelEntry.createNNTestBase(/*useNnnapi=*/ true, 151 /*enableIntermediateTensorsDump=*/false, 152 /*mmapModel=*/ false)) { 153 tb.setNNApiDeviceName(acceleratorName); 154 return tb.setupModel(context); 155 } catch (IOException e) { 156 Log.w(TAG, 157 String.format("Error trying to check support for model %s on accelerator %s", 158 testModelEntry.mModelName, acceleratorName), e); 159 return false; 160 } catch (NnApiDelegationFailure nnApiDelegationFailure) { 161 if (nnApiDelegationFailure.getNnApiErrno() == 4 /*ANEURALNETWORKS_BAD_DATA*/) { 162 // Compilation will fail with ANEURALNETWORKS_BAD_DATA if the device is not 163 // supporting all operation in the model 164 return false; 165 } 166 167 throw nnApiDelegationFailure; 168 } 169 } 170 changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)171 private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t) 172 throws IOException, UnsupportedModelException, NnApiDelegationFailure { 173 if (oldTestBase != null) { 174 // Make sure we don't leak memory. 175 oldTestBase.destroy(); 176 } 177 NNTestBase tb = t.createNNTestBase(mUseNNApi, /*enableIntermediateTensorsDump=*/false, 178 mMmapModel); 179 if (mUseNNApi) { 180 tb.setNNApiDeviceName(mAcceleratorName); 181 } 182 if (!tb.setupModel(mContext)) { 183 throw new UnsupportedModelException("Cannot initialise model"); 184 } 185 return tb; 186 } 187 188 // Run one loop of kernels for at least the specified minimum time. 189 // The function returns the average time in ms for the test run runBenchmarkLoop(float minTime, boolean completeInputSet)190 private BenchmarkResult runBenchmarkLoop(float minTime, boolean completeInputSet) 191 throws IOException { 192 try { 193 // Run the kernel 194 Pair<List<InferenceInOutSequence>, List<InferenceResult>> results; 195 if (minTime > 0.f) { 196 if (completeInputSet) { 197 results = mTest.runBenchmarkCompleteInputSet(1, minTime); 198 } else { 199 results = mTest.runBenchmark(minTime); 200 } 201 } else { 202 results = mTest.runInferenceOnce(); 203 } 204 return BenchmarkResult.fromInferenceResults( 205 mTest.getTestInfo(), 206 mUseNNApi 207 ? BenchmarkResult.BACKEND_TFLITE_NNAPI 208 : BenchmarkResult.BACKEND_TFLITE_CPU, 209 results.first, 210 results.second, 211 mTest.getEvaluator()); 212 } catch (BenchmarkException e) { 213 return new BenchmarkResult(e.getMessage()); 214 } 215 } 216 217 // Run one loop of compilations for at least the specified minimum time. 218 // The function will set the compilation results into the provided benchmark result object. runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime, int maxIterations, BenchmarkResult benchmarkResult)219 private void runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime, 220 int maxIterations, BenchmarkResult benchmarkResult) throws IOException { 221 try { 222 CompilationBenchmarkResult result = 223 mTest.runCompilationBenchmark(warmupMinTime, runMinTime, maxIterations); 224 benchmarkResult.setCompilationBenchmarkResult(result); 225 } catch (BenchmarkException e) { 226 benchmarkResult.setBenchmarkError(e.getMessage()); 227 } 228 } 229 getTestResults()230 public BenchmarkResult[] getTestResults() { 231 return mTestResults; 232 } 233 234 // Get a benchmark result for a specific test getBenchmark(float warmupTimeSeconds, float runTimeSeconds)235 private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds) 236 throws IOException { 237 try { 238 mTest.checkSdkVersion(); 239 } catch (UnsupportedSdkException e) { 240 BenchmarkResult r = new BenchmarkResult(e.getMessage()); 241 Log.w(TAG, "Unsupported SDK for test: " + r.toString()); 242 return r; 243 } 244 245 // We run a short bit of work before starting the actual test 246 // this is to let any power management do its job and respond. 247 // For NNAPI systrace usage documentation, see 248 // frameworks/ml/nn/common/include/Tracing.h. 249 try { 250 final String traceName = "[NN_LA_PWU]runBenchmarkLoop"; 251 Trace.beginSection(traceName); 252 runBenchmarkLoop(warmupTimeSeconds, false); 253 } finally { 254 Trace.endSection(); 255 } 256 257 // Run the actual benchmark 258 BenchmarkResult r; 259 try { 260 final String traceName = "[NN_LA_PBM]runBenchmarkLoop"; 261 Trace.beginSection(traceName); 262 r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet); 263 } finally { 264 Trace.endSection(); 265 } 266 267 // Compilation benchmark 268 if (mUseNNApi && mBenchmarkCompilationCaching) { 269 runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds, 270 mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r); 271 } 272 273 return r; 274 } 275 276 @Override run()277 public void run() { 278 mHasBeenStarted = true; 279 Log.d(TAG, "Processor starting"); 280 boolean success = true; 281 int benchmarkIterationsCount = 0; 282 try { 283 while (mRun.get()) { 284 if (mMaxRunIterations > 0 && benchmarkIterationsCount >= mMaxRunIterations) { 285 break; 286 } 287 benchmarkIterationsCount++; 288 try { 289 benchmarkAllModels(); 290 } catch (IOException | BenchmarkException e) { 291 Log.e(TAG, "Exception during benchmark run", e); 292 success = false; 293 break; 294 } catch (Throwable e) { 295 Log.e(TAG, "Error during execution", e); 296 throw e; 297 } 298 } 299 Log.d(TAG, "Processor completed work"); 300 mCallback.onBenchmarkFinish(success); 301 } finally { 302 if (mTest != null) { 303 // Make sure we don't leak memory. 304 mTest.destroy(); 305 mTest = null; 306 } 307 mCompleted.countDown(); 308 } 309 } 310 benchmarkAllModels()311 private void benchmarkAllModels() throws IOException, BenchmarkException { 312 // Loop over the tests we want to benchmark 313 for (int ct = 0; ct < mTestList.length; ct++) { 314 if (!mRun.get()) { 315 Log.v(TAG, String.format("Asked to stop execution at model #%d", ct)); 316 break; 317 } 318 // For reproducibility we wait a short time for any sporadic work 319 // created by the user touching the screen to launch the test to pass. 320 // Also allows for things to settle after the test changes. 321 try { 322 Thread.sleep(250); 323 } catch (InterruptedException ignored) { 324 Thread.currentThread().interrupt(); 325 break; 326 } 327 328 TestModels.TestModelEntry testModel = 329 TestModels.modelsList().get(mTestList[ct]); 330 331 int testNumber = ct + 1; 332 mCallback.onStatusUpdate(testNumber, mTestList.length, 333 testModel.toString()); 334 335 // Select the next test 336 try { 337 mTest = changeTest(mTest, testModel); 338 } catch (UnsupportedModelException e) { 339 if (mIgnoreUnsupportedModels) { 340 Log.d(TAG, String.format( 341 "Cannot initialise test %d: '%s' on accelerator %s, skipping", ct, 342 testModel.mTestName, mAcceleratorName)); 343 } else { 344 Log.e(TAG, 345 String.format("Cannot initialise test %d: '%s' on accelerator %s.", ct, 346 testModel.mTestName, mAcceleratorName), e); 347 throw e; 348 } 349 } 350 351 // If the user selected the "long pause" option, wait 352 if (mTogglePause) { 353 for (int i = 0; (i < 100) && mRun.get(); i++) { 354 try { 355 Thread.sleep(100); 356 } catch (InterruptedException ignored) { 357 Thread.currentThread().interrupt(); 358 break; 359 } 360 } 361 } 362 363 if (mRunModelCompilationOnly) { 364 mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName, 365 mUseNNApi 366 ? BenchmarkResult.BACKEND_TFLITE_NNAPI 367 : BenchmarkResult.BACKEND_TFLITE_CPU, Collections.emptyList(), 368 Collections.emptyList(), null); 369 } else { 370 // Run the test 371 float warmupTime = 0.3f; 372 float runTime = 1.f; 373 if (mToggleLong) { 374 warmupTime = 2.f; 375 runTime = 10.f; 376 } 377 mTestResults[ct] = getBenchmark(warmupTime, runTime); 378 } 379 } 380 } 381 exit()382 public void exit() { 383 exitWithTimeout(-1l); 384 } 385 exitWithTimeout(long timeoutMs)386 public void exitWithTimeout(long timeoutMs) { 387 mRun.set(false); 388 389 if (mHasBeenStarted) { 390 Log.d(TAG, String.format("Terminating, timeout is %d ms", timeoutMs)); 391 try { 392 if (timeoutMs > 0) { 393 boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS); 394 if (!hasCompleted) { 395 Log.w(TAG, "Exiting before execution actually completed"); 396 } 397 } else { 398 mCompleted.await(); 399 } 400 } catch (InterruptedException e) { 401 Thread.currentThread().interrupt(); 402 Log.w(TAG, "Interrupted while waiting for Processor to complete", e); 403 } 404 } 405 406 Log.d(TAG, "Done, cleaning up"); 407 408 if (mTest != null) { 409 mTest.destroy(); 410 mTest = null; 411 } 412 } 413 } 414