/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.crashtest.core.test;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.Intent;
import android.util.Log;

import com.android.nn.benchmark.core.BenchmarkException;
import com.android.nn.benchmark.core.BenchmarkResult;
import com.android.nn.benchmark.core.Processor;
import com.android.nn.benchmark.core.TestModels;
import com.android.nn.crashtest.app.AcceleratorSpecificTestSupport;
import com.android.nn.crashtest.core.CrashTest;
import com.android.nn.crashtest.core.CrashTestCoordinator;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Stream;

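/**
 * Crash test that checks how much the inference time of a model degrades while other threads
 * keep compiling models on the same accelerator. A baseline run of the inference model is
 * compared against a run performed while {@code thread_count} compilation threads are active;
 * the test fails if the relative slowdown exceeds {@code max_performance_degradation} percent.
 */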
public class PerformanceDegradationTest implements CrashTest {
    public static final String TAG = "NN_PERF_DEG";

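    // Progress callback that ignores all notifications: results are read directly from the
    // BenchmarkResult instances returned by the Processor.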
    private static final Processor.Callback mNoOpCallback = new Processor.Callback() {
        @Override
        public void onBenchmarkFinish(boolean ok) {
        }

        @Override
        public void onStatusUpdate(int testNumber, int numTests, String modelName) {
        }
    };

    public static final String WARMUP_SECONDS = "warmup_seconds";
    public static final String RUN_TIME_SECONDS = "run_time_seconds";
    public static final String ACCELERATOR_NAME = "accelerator_name";
    public static final float DEFAULT_WARMUP_SECONDS = 3.0f;
    public static final float DEFAULT_RUN_TIME_SECONDS = 10.0f;
    public static final String THREAD_COUNT = "thread_count";
    public static final int DEFAULT_THREAD_COUNT = 5;
    public static final String MAX_PERFORMANCE_DEGRADATION = "max_performance_degradation";
    public static final int DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE = 100;
    public static final String TEST_NAME = "test_name";
    private static final long INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS = 500;

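    /**
     * Returns an intent initializer that writes the test configuration into the intent extras
     * read back by {@link #init}.
     */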
    public static CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(
            float warmupTimeSeconds, float runTimeSeconds, String acceleratorName, int threadCount,
            int maxPerformanceDegradationPercent, String testName) {
        return intent -> {
            intent.putExtra(WARMUP_SECONDS, warmupTimeSeconds);
            intent.putExtra(RUN_TIME_SECONDS, runTimeSeconds);
            intent.putExtra(ACCELERATOR_NAME, acceleratorName);
            intent.putExtra(THREAD_COUNT, threadCount);
            intent.putExtra(MAX_PERFORMANCE_DEGRADATION, maxPerformanceDegradationPercent);
            intent.putExtra(TEST_NAME, testName);
        };
    }

    public static CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(
            Intent copyFrom) {
        return intentInitializer(
                copyFrom.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS),
                copyFrom.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS),
                copyFrom.getStringExtra(ACCELERATOR_NAME),
                copyFrom.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT),
                copyFrom.getIntExtra(MAX_PERFORMANCE_DEGRADATION,
                        DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE),
                copyFrom.getStringExtra(TEST_NAME));
    }

    private Context mContext;
    private float mWarmupTimeSeconds;
    private float mRunTimeSeconds;
    private String mAcceleratorName;
    private int mThreadCount;
    private int mMaxPerformanceDegradationPercent;
    private String mTestName;

    @Override
    public void init(Context context, Intent configParams,
            Optional<ProgressListener> progressListener) {
        mContext = context;

        mWarmupTimeSeconds = configParams.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS);
        mRunTimeSeconds = configParams.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS);
        mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME);
        mThreadCount = configParams.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT);
        mMaxPerformanceDegradationPercent = configParams.getIntExtra(MAX_PERFORMANCE_DEGRADATION,
                DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE);
        mTestName = configParams.getStringExtra(TEST_NAME);
    }

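    // Entry point of the test: finds all models supported by the target accelerator, then checks
    // the inference-time degradation once for every available compilation model.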
    @SuppressLint("DefaultLocale")
    @Override
    public Optional<String> call() throws Exception {
        List<TestModels.TestModelEntry> modelsForAccelerator =
                AcceleratorSpecificTestSupport.findAllTestModelsRunningOnAccelerator(mContext,
                        mAcceleratorName);

        if (modelsForAccelerator.isEmpty()) {
            return failure("Cannot find any model to use for testing");
        }

        Log.i(TAG, String.format("Checking performance degradation using %d models",
                modelsForAccelerator.size()));

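        // The first available model is the one whose inference time is measured.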
        TestModels.TestModelEntry modelForInference = modelsForAccelerator.get(0);
        // The performance degradation is strongly dependent on the model used to compile,
        // so we check all the available ones.
        for (TestModels.TestModelEntry modelForCompilation : modelsForAccelerator) {
            Optional<String> currTestResult = testDegradationForModels(modelForInference,
                    modelForCompilation);
            if (isFailure(currTestResult)) {
                return currTestResult;
            }
        }

        return success();
    }

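    /**
     * Measures the inference time of {@code inferenceModelEntry} alone and then again while
     * {@code mThreadCount} threads keep compiling {@code compilationModelEntry}, returning a
     * failure if the relative slowdown exceeds {@code mMaxPerformanceDegradationPercent}.
     */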
    @SuppressLint("DefaultLocale")
    public Optional<String> testDegradationForModels(
            TestModels.TestModelEntry inferenceModelEntry,
            TestModels.TestModelEntry compilationModelEntry) throws Exception {
        Log.i(TAG, String.format(
                "Testing degradation in inference of model %s when running %d threads compiling "
                        + "model %s",
                inferenceModelEntry.mModelName, mThreadCount, compilationModelEntry.mModelName));

        Log.d(TAG, String.format("%s: Calculating baseline", mTestName));
        // First, measure the baseline performance with no competing compilation work.
        final BenchmarkResult baseline = modelPerformanceCollector(inferenceModelEntry,
                /*start=*/ null).call();
        if (baseline.hasBenchmarkError()) {
            return failure(String.format("%s: Baseline has benchmark error '%s'",
                    mTestName, baseline.getBenchmarkError()));
        }
        Log.d(TAG, String.format("%s: Baseline mean time is %f seconds", mTestName,
                baseline.getMeanTimeSec()));

        Log.d(TAG, String.format("%s: Sleeping for %d millis", mTestName,
                INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS));
        Thread.sleep(INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS);

        Log.d(TAG, String.format("%s: Calculating performance with %d threads", mTestName,
                mThreadCount));
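        // Start mThreadCount compilation workers plus the measurement thread; the latch makes
        // all of them begin working at the same time.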
        final int totalThreadCount = mThreadCount + 1;
        final CountDownLatch start = new CountDownLatch(totalThreadCount);
        ModelCompiler[] compilers = Stream.generate(
                () -> new ModelCompiler(start, mContext, mAcceleratorName, compilationModelEntry))
                .limit(mThreadCount)
                .toArray(ModelCompiler[]::new);

        Callable<BenchmarkResult> performanceWithOtherCompilingThreadCollector =
                modelPerformanceCollector(inferenceModelEntry, start);

        ExecutorService testExecutor = Executors.newFixedThreadPool(totalThreadCount);
        Future<?>[] compilerFutures = Arrays.stream(compilers).map(testExecutor::submit).toArray(
                Future[]::new);
        BenchmarkResult benchmarkWithOtherCompilingThread = testExecutor.submit(
                performanceWithOtherCompilingThreadCollector).get();

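        // The measurement is complete: tell the compilation workers to stop and wait for them.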
        Arrays.stream(compilers).forEach(ModelCompiler::stop);
        Arrays.stream(compilerFutures).forEach(future -> {
            try {
                future.get();
            } catch (InterruptedException | ExecutionException e) {
                Log.e(TAG, "Error waiting for compiler process completion", e);
            }
        });
        // Release the worker threads; the executor is created per call and is not reused.
        testExecutor.shutdown();

        if (benchmarkWithOtherCompilingThread.hasBenchmarkError()) {
            return failure(
                    String.format(
                            "%s: Test with parallel compiling threads has benchmark error '%s'",
                            mTestName, benchmarkWithOtherCompilingThread.getBenchmarkError()));
        }

        Log.d(TAG, String.format("%s: Multithreaded mean time is %f seconds",
                mTestName, benchmarkWithOtherCompilingThread.getMeanTimeSec()));

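        // Degradation is the relative increase of the mean inference time over the baseline,
        // expressed as a percentage.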
        int performanceDegradation = (int) (((benchmarkWithOtherCompilingThread.getMeanTimeSec()
                / baseline.getMeanTimeSec()) - 1.0) * 100);

        Log.i(TAG, String.format(
                "%s: Performance degradation for accelerator %s, with %d threads is %d%%. "
                        + "Threshold is %d%%",
                mTestName, mAcceleratorName, mThreadCount, performanceDegradation,
                mMaxPerformanceDegradationPercent));

        if (performanceDegradation > mMaxPerformanceDegradationPercent) {
            return failure(String.format("Performance degradation is %d%%. Max acceptable is %d%%",
                    performanceDegradation, mMaxPerformanceDegradationPercent));
        }

        return success();
    }

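    /**
     * Returns a Callable that runs the inference model through the NNAPI Processor and collects
     * its timing. If a start latch is given, the measurement waits until every participating
     * thread is ready before running.
     */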
    private Callable<BenchmarkResult> modelPerformanceCollector(
            final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start) {
        return () -> {
            Processor benchmarkProcessor = new Processor(mContext, mNoOpCallback, new int[0]);
            benchmarkProcessor.setUseNNApi(true);
            benchmarkProcessor.setNnApiAcceleratorName(mAcceleratorName);
            if (start != null) {
                start.countDown();
                start.await();
            }
            final BenchmarkResult result =
                    benchmarkProcessor.getInstrumentationResult(
                            inferenceModelEntry, mWarmupTimeSeconds, mRunTimeSeconds);

            return result;
        };
    }

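    /**
     * Worker that repeatedly compiles the given model on the target accelerator until
     * {@link #stop()} is called, pausing briefly between compilations.
     */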
    private static class ModelCompiler implements Callable<Void> {
        private static final long SLEEP_BETWEEN_COMPILATION_INTERVAL_MS = 20;
        private final CountDownLatch mStart;
        private final Processor mProcessor;
        private final TestModels.TestModelEntry mTestModelEntry;
        private volatile boolean mRun;

        ModelCompiler(final CountDownLatch start, final Context context,
                final String acceleratorName, TestModels.TestModelEntry testModelEntry) {
            mStart = start;
            mTestModelEntry = testModelEntry;
            mProcessor = new Processor(context, mNoOpCallback, new int[0]);
            mProcessor.setUseNNApi(true);
            mProcessor.setNnApiAcceleratorName(acceleratorName);
            mProcessor.setRunModelCompilationOnly(true);
            mRun = true;
        }

        @Override
        public Void call() throws IOException, BenchmarkException {
            if (mStart != null) {
                try {
                    mStart.countDown();
                    mStart.await();
                } catch (InterruptedException e) {
                    // Restore the interrupt flag and stop processing.
                    Thread.currentThread().interrupt();
                    Log.i(TAG, "Interrupted, stopping processing");
                    return null;
                }
            }
            while (mRun) {
                mProcessor.getInstrumentationResult(mTestModelEntry, 0, 0);
                try {
                    Thread.sleep(SLEEP_BETWEEN_COMPILATION_INTERVAL_MS);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return null;
                }
            }
            return null;
        }

        public void stop() {
            mRun = false;
        }
    }
}