/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import static java.util.concurrent.TimeUnit.MILLISECONDS;

import android.content.Context;
import android.os.Trace;
import android.util.Log;
import android.util.Pair;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

/** Processor is a helper thread for running the work without blocking the UI thread. */
public class Processor implements Runnable {

    public interface Callback {
        void onBenchmarkFinish(boolean ok);

        void onStatusUpdate(int testNumber, int numTests, String modelName);
    }
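
    // Illustrative usage sketch, not part of the original API surface: a host
    // component (for example an Activity or an instrumentation test) constructs
    // the Processor with a Callback and a list of indices into
    // TestModels.modelsList(), configures it through the setters below, runs it
    // on its own thread, and eventually calls exit()/exitWithTimeout(). The
    // callback body and test indices here are hypothetical.
    //
    //   Processor processor = new Processor(context, new Processor.Callback() {
    //       @Override
    //       public void onBenchmarkFinish(boolean ok) { /* report the overall outcome */ }
    //
    //       @Override
    //       public void onStatusUpdate(int testNumber, int numTests, String modelName) {
    //           /* update progress UI */
    //       }
    //   }, new int[] {0, 1, 2});
    //   processor.setUseNNApi(true);
    //   new Thread(processor).start();
    //   // ... later, when the run should stop:
    //   processor.exitWithTimeout(60 * 1000L);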

    protected static final String TAG = "NN_BENCHMARK";
    private Context mContext;

    private final AtomicBoolean mRun = new AtomicBoolean(true);

    volatile boolean mHasBeenStarted = false;
    // You cannot restart a thread, so the completion flag is final
    private final CountDownLatch mCompleted = new CountDownLatch(1);
    private NNTestBase mTest;
    private int[] mTestList;
    private BenchmarkResult[] mTestResults;

    private Processor.Callback mCallback;

    private boolean mUseNNApi;
    private boolean mMmapModel;
    private boolean mCompleteInputSet;
    private boolean mToggleLong;
    private boolean mTogglePause;
    private String mAcceleratorName;
    private boolean mIgnoreUnsupportedModels;
    private boolean mRunModelCompilationOnly;
    // Max number of benchmark iterations to do in the run method.
    // A value less than or equal to 0 means unlimited.
    private int mMaxRunIterations;

    private boolean mBenchmarkCompilationCaching;
    private float mCompilationBenchmarkWarmupTimeSeconds;
    private float mCompilationBenchmarkRunTimeSeconds;
    private int mCompilationBenchmarkMaxIterations;

    public Processor(Context context, Processor.Callback callback, int[] testList) {
        mContext = context;
        mCallback = callback;
        mTestList = testList;
        if (mTestList != null) {
            mTestResults = new BenchmarkResult[mTestList.length];
        }
        mAcceleratorName = null;
        mIgnoreUnsupportedModels = false;
        mRunModelCompilationOnly = false;
        mMaxRunIterations = 0;
        mBenchmarkCompilationCaching = false;
    }

    public void setUseNNApi(boolean useNNApi) {
        mUseNNApi = useNNApi;
    }

    public void setCompleteInputSet(boolean completeInputSet) {
        mCompleteInputSet = completeInputSet;
    }

    public void setToggleLong(boolean toggleLong) {
        mToggleLong = toggleLong;
    }

    public void setTogglePause(boolean togglePause) {
        mTogglePause = togglePause;
    }

    public void setNnApiAcceleratorName(String acceleratorName) {
        mAcceleratorName = acceleratorName;
    }

    public void setIgnoreUnsupportedModels(boolean value) {
        mIgnoreUnsupportedModels = value;
    }

    public void setRunModelCompilationOnly(boolean value) {
        mRunModelCompilationOnly = value;
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }

    public void setMaxRunIterations(int value) {
        mMaxRunIterations = value;
    }

    public void enableCompilationCachingBenchmarks(
            float warmupTimeSeconds, float runTimeSeconds, int maxIterations) {
        mBenchmarkCompilationCaching = true;
        mCompilationBenchmarkWarmupTimeSeconds = warmupTimeSeconds;
        mCompilationBenchmarkRunTimeSeconds = runTimeSeconds;
        mCompilationBenchmarkMaxIterations = maxIterations;
    }

    // Method to retrieve benchmark results for instrumentation tests.
    // Returns null if the processor is configured to run compilation only.
    public BenchmarkResult getInstrumentationResult(
            TestModels.TestModelEntry t, float warmupTimeSeconds, float runTimeSeconds)
            throws IOException, BenchmarkException {
        mTest = changeTest(mTest, t);
        try {
            BenchmarkResult result = mRunModelCompilationOnly ? null : getBenchmark(warmupTimeSeconds,
                    runTimeSeconds);
            return result;
        } finally {
            mTest.destroy();
            mTest = null;
        }
    }
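
    // Sketch of how an instrumentation test might drive getInstrumentationResult()
    // above; the callback instance, model index and timings are assumptions for
    // illustration only. The call throws IOException/BenchmarkException, which the
    // test is expected to propagate or handle.
    //
    //   Processor processor = new Processor(context, callback, /*testList=*/ null);
    //   processor.setUseNNApi(true);
    //   TestModels.TestModelEntry entry = TestModels.modelsList().get(0);
    //   BenchmarkResult result = processor.getInstrumentationResult(
    //           entry, /*warmupTimeSeconds=*/ 0.3f, /*runTimeSeconds=*/ 1.0f);
    //   // result is null when setRunModelCompilationOnly(true) was requested.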

    public static boolean isTestModelSupportedByAccelerator(Context context,
            TestModels.TestModelEntry testModelEntry, String acceleratorName)
            throws NnApiDelegationFailure {
        try (NNTestBase tb = testModelEntry.createNNTestBase(/*useNNApi=*/ true,
                /*enableIntermediateTensorsDump=*/false,
                /*mmapModel=*/ false)) {
            tb.setNNApiDeviceName(acceleratorName);
            return tb.setupModel(context);
        } catch (IOException e) {
            Log.w(TAG,
                    String.format("Error trying to check support for model %s on accelerator %s",
                            testModelEntry.mModelName, acceleratorName), e);
            return false;
        } catch (NnApiDelegationFailure nnApiDelegationFailure) {
            if (nnApiDelegationFailure.getNnApiErrno() == 4 /*ANEURALNETWORKS_BAD_DATA*/) {
                // Compilation fails with ANEURALNETWORKS_BAD_DATA if the device does not
                // support all of the operations in the model.
                return false;
            }

            throw nnApiDelegationFailure;
        }
    }
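
    // Hypothetical sketch of using isTestModelSupportedByAccelerator() to filter
    // the model list before scheduling a benchmark on a specific accelerator; the
    // accelerator name below is an example value, not one defined in this class.
    //
    //   List<TestModels.TestModelEntry> supported = new ArrayList<>();
    //   for (TestModels.TestModelEntry entry : TestModels.modelsList()) {
    //       if (Processor.isTestModelSupportedByAccelerator(context, entry, "sample-accelerator")) {
    //           supported.add(entry);
    //       }
    //   }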

    private NNTestBase changeTest(NNTestBase oldTestBase, TestModels.TestModelEntry t)
            throws IOException, UnsupportedModelException, NnApiDelegationFailure {
        if (oldTestBase != null) {
            // Make sure we don't leak memory.
            oldTestBase.destroy();
        }
        NNTestBase tb = t.createNNTestBase(mUseNNApi, /*enableIntermediateTensorsDump=*/false,
                mMmapModel);
        if (mUseNNApi) {
            tb.setNNApiDeviceName(mAcceleratorName);
        }
        if (!tb.setupModel(mContext)) {
            throw new UnsupportedModelException("Cannot initialise model");
        }
        return tb;
    }

    // Run one loop of kernels for at least the specified minimum time.
    // Returns the benchmark result for the run, including the average inference time in ms.
    private BenchmarkResult runBenchmarkLoop(float minTime, boolean completeInputSet)
            throws IOException {
        try {
            // Run the kernel
            Pair<List<InferenceInOutSequence>, List<InferenceResult>> results;
            if (minTime > 0.f) {
                if (completeInputSet) {
                    results = mTest.runBenchmarkCompleteInputSet(1, minTime);
                } else {
                    results = mTest.runBenchmark(minTime);
                }
            } else {
                results = mTest.runInferenceOnce();
            }
            return BenchmarkResult.fromInferenceResults(
                    mTest.getTestInfo(),
                    mUseNNApi
                            ? BenchmarkResult.BACKEND_TFLITE_NNAPI
                            : BenchmarkResult.BACKEND_TFLITE_CPU,
                    results.first,
                    results.second,
                    mTest.getEvaluator());
        } catch (BenchmarkException e) {
            return new BenchmarkResult(e.getMessage());
        }
    }

    // Run one loop of compilations for at least the specified minimum time.
    // The function will set the compilation results into the provided benchmark result object.
    private void runCompilationBenchmarkLoop(float warmupMinTime, float runMinTime,
            int maxIterations, BenchmarkResult benchmarkResult) throws IOException {
        try {
            CompilationBenchmarkResult result =
                    mTest.runCompilationBenchmark(warmupMinTime, runMinTime, maxIterations);
            benchmarkResult.setCompilationBenchmarkResult(result);
        } catch (BenchmarkException e) {
            benchmarkResult.setBenchmarkError(e.getMessage());
        }
    }

    public BenchmarkResult[] getTestResults() {
        return mTestResults;
    }

    // Get a benchmark result for a specific test
    private BenchmarkResult getBenchmark(float warmupTimeSeconds, float runTimeSeconds)
            throws IOException {
        try {
            mTest.checkSdkVersion();
        } catch (UnsupportedSdkException e) {
            BenchmarkResult r = new BenchmarkResult(e.getMessage());
            Log.w(TAG, "Unsupported SDK for test: " + r.toString());
            return r;
        }

        // We run a short bit of work before starting the actual test;
        // this lets any power management do its job and respond.
        // For NNAPI systrace usage documentation, see
        // frameworks/ml/nn/common/include/Tracing.h.
        try {
            final String traceName = "[NN_LA_PWU]runBenchmarkLoop";
            Trace.beginSection(traceName);
            runBenchmarkLoop(warmupTimeSeconds, false);
        } finally {
            Trace.endSection();
        }

        // Run the actual benchmark
        BenchmarkResult r;
        try {
            final String traceName = "[NN_LA_PBM]runBenchmarkLoop";
            Trace.beginSection(traceName);
            r = runBenchmarkLoop(runTimeSeconds, mCompleteInputSet);
        } finally {
            Trace.endSection();
        }

        // Compilation benchmark
        if (mUseNNApi && mBenchmarkCompilationCaching) {
            runCompilationBenchmarkLoop(mCompilationBenchmarkWarmupTimeSeconds,
                    mCompilationBenchmarkRunTimeSeconds, mCompilationBenchmarkMaxIterations, r);
        }

        return r;
    }

    @Override
    public void run() {
        mHasBeenStarted = true;
        Log.d(TAG, "Processor starting");
        boolean success = true;
        int benchmarkIterationsCount = 0;
        try {
            while (mRun.get()) {
                if (mMaxRunIterations > 0 && benchmarkIterationsCount >= mMaxRunIterations) {
                    break;
                }
                benchmarkIterationsCount++;
                try {
                    benchmarkAllModels();
                } catch (IOException | BenchmarkException e) {
                    Log.e(TAG, "Exception during benchmark run", e);
                    success = false;
                    break;
                } catch (Throwable e) {
                    Log.e(TAG, "Error during execution", e);
                    throw e;
                }
            }
            Log.d(TAG, "Processor completed work");
            mCallback.onBenchmarkFinish(success);
        } finally {
            if (mTest != null) {
                // Make sure we don't leak memory.
                mTest.destroy();
                mTest = null;
            }
            mCompleted.countDown();
        }
    }

    private void benchmarkAllModels() throws IOException, BenchmarkException {
        // Loop over the tests we want to benchmark
        for (int ct = 0; ct < mTestList.length; ct++) {
            if (!mRun.get()) {
                Log.v(TAG, String.format("Asked to stop execution at model #%d", ct));
                break;
            }
            // For reproducibility we wait a short time so that any sporadic work created
            // by the user touching the screen to launch the test has time to complete.
            // This also allows things to settle after the test changes.
            try {
                Thread.sleep(250);
            } catch (InterruptedException ignored) {
                Thread.currentThread().interrupt();
                break;
            }

            TestModels.TestModelEntry testModel =
                    TestModels.modelsList().get(mTestList[ct]);

            int testNumber = ct + 1;
            mCallback.onStatusUpdate(testNumber, mTestList.length,
                    testModel.toString());

            // Select the next test
            try {
                mTest = changeTest(mTest, testModel);
            } catch (UnsupportedModelException e) {
                if (mIgnoreUnsupportedModels) {
                    Log.d(TAG, String.format(
                            "Cannot initialise test %d: '%s' on accelerator %s, skipping", ct,
                            testModel.mTestName, mAcceleratorName));
                } else {
                    Log.e(TAG,
                            String.format("Cannot initialise test %d: '%s' on accelerator %s.", ct,
                                    testModel.mTestName, mAcceleratorName), e);
                    throw e;
                }
            }

            // If the user selected the "long pause" option, wait
            if (mTogglePause) {
                for (int i = 0; (i < 100) && mRun.get(); i++) {
                    try {
                        Thread.sleep(100);
                    } catch (InterruptedException ignored) {
                        Thread.currentThread().interrupt();
                        break;
                    }
                }
            }

            if (mRunModelCompilationOnly) {
                mTestResults[ct] = BenchmarkResult.fromInferenceResults(testModel.mTestName,
                        mUseNNApi
                                ? BenchmarkResult.BACKEND_TFLITE_NNAPI
                                : BenchmarkResult.BACKEND_TFLITE_CPU, Collections.emptyList(),
                        Collections.emptyList(), null);
            } else {
                // Run the test
                float warmupTime = 0.3f;
                float runTime = 1.f;
                if (mToggleLong) {
                    warmupTime = 2.f;
                    runTime = 10.f;
                }
                mTestResults[ct] = getBenchmark(warmupTime, runTime);
            }
        }
    }

    public void exit() {
        exitWithTimeout(-1L);
    }

    public void exitWithTimeout(long timeoutMs) {
        mRun.set(false);

        if (mHasBeenStarted) {
            Log.d(TAG, String.format("Terminating, timeout is %d ms", timeoutMs));
            try {
                if (timeoutMs > 0) {
                    boolean hasCompleted = mCompleted.await(timeoutMs, MILLISECONDS);
                    if (!hasCompleted) {
                        Log.w(TAG, "Exiting before execution actually completed");
                    }
                } else {
                    mCompleted.await();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                Log.w(TAG, "Interrupted while waiting for Processor to complete", e);
            }
        }

        Log.d(TAG, "Done, cleaning up");

        if (mTest != null) {
            mTest.destroy();
            mTest = null;
        }
    }
}