1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.crashtest.core.test; 18 19 import android.annotation.SuppressLint; 20 import android.content.Context; 21 import android.content.Intent; 22 import android.util.Log; 23 24 import com.android.nn.benchmark.core.BenchmarkException; 25 import com.android.nn.benchmark.core.BenchmarkResult; 26 import com.android.nn.benchmark.core.Processor; 27 import com.android.nn.benchmark.core.TestModels; 28 import com.android.nn.crashtest.app.AcceleratorSpecificTestSupport; 29 import com.android.nn.crashtest.core.CrashTest; 30 import com.android.nn.crashtest.core.CrashTestCoordinator; 31 32 import java.io.IOException; 33 import java.util.Arrays; 34 import java.util.List; 35 import java.util.Optional; 36 import java.util.concurrent.Callable; 37 import java.util.concurrent.CountDownLatch; 38 import java.util.concurrent.ExecutionException; 39 import java.util.concurrent.ExecutorService; 40 import java.util.concurrent.Executors; 41 import java.util.concurrent.Future; 42 import java.util.stream.Stream; 43 44 public class PerformanceDegradationTest implements CrashTest { 45 public static final String TAG = "NN_PERF_DEG"; 46 47 private static final Processor.Callback mNoOpCallback = new Processor.Callback() { 48 @Override 49 public void onBenchmarkFinish(boolean ok) { 50 } 51 52 @Override 53 public void onStatusUpdate(int testNumber, int numTests, String modelName) { 54 } 55 }; 56 57 public static final String WARMUP_SECONDS = "warmup_seconds"; 58 public static final String RUN_TIME_SECONDS = "run_time_seconds"; 59 public static final String ACCELERATOR_NAME = "accelerator_name"; 60 public static final float DEFAULT_WARMUP_SECONDS = 3.0f; 61 public static final float DEFAULT_RUN_TIME_SECONDS = 10.0f; 62 public static final String THREAD_COUNT = "thread_count"; 63 public static final int DEFAULT_THREAD_COUNT = 5; 64 public static final String MAX_PERFORMANCE_DEGRADATION = "max_performance_degradation"; 65 public static final int DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE = 100; 66 public static final String TEST_NAME = "test_name"; 67 private static final long INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS = 500; 68 intentInitializer( float warmupTimeSeconds, float runTimeSeconds, String acceleratorName, int threadCount, int maxPerformanceDegradationPercent, String testName)69 static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer( 70 float warmupTimeSeconds, float runTimeSeconds, String acceleratorName, int threadCount, 71 int maxPerformanceDegradationPercent, String testName) { 72 return intent -> { 73 intent.putExtra(WARMUP_SECONDS, warmupTimeSeconds); 74 intent.putExtra(RUN_TIME_SECONDS, runTimeSeconds); 75 intent.putExtra(ACCELERATOR_NAME, acceleratorName); 76 intent.putExtra(THREAD_COUNT, threadCount); 77 intent.putExtra(MAX_PERFORMANCE_DEGRADATION, maxPerformanceDegradationPercent); 78 intent.putExtra(TEST_NAME, testName); 79 }; 80 } 81 intentInitializer( Intent copyFrom)82 static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer( 83 Intent copyFrom) { 84 return intentInitializer( 85 copyFrom.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS), 86 copyFrom.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS), 87 copyFrom.getStringExtra(ACCELERATOR_NAME), 88 copyFrom.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT), 89 copyFrom.getIntExtra(MAX_PERFORMANCE_DEGRADATION, 90 DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE), 91 copyFrom.getStringExtra(TEST_NAME)); 92 } 93 94 private Context mContext; 95 private float mWarmupTimeSeconds; 96 private float mRunTimeSeconds; 97 private String mAcceleratorName; 98 private int mThreadCount; 99 private int mMaxPerformanceDegradationPercent; 100 private String mTestName; 101 102 @Override init(Context context, Intent configParams, Optional<ProgressListener> progressListener)103 public void init(Context context, Intent configParams, 104 Optional<ProgressListener> progressListener) { 105 mContext = context; 106 107 mWarmupTimeSeconds = configParams.getFloatExtra(WARMUP_SECONDS, DEFAULT_WARMUP_SECONDS); 108 mRunTimeSeconds = configParams.getFloatExtra(RUN_TIME_SECONDS, DEFAULT_RUN_TIME_SECONDS); 109 mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME); 110 mThreadCount = configParams.getIntExtra(THREAD_COUNT, DEFAULT_THREAD_COUNT); 111 mMaxPerformanceDegradationPercent = configParams.getIntExtra(MAX_PERFORMANCE_DEGRADATION, 112 DEFAULT_MAX_PERFORMANCE_DEGRADATION_PERCENTAGE); 113 mTestName = configParams.getStringExtra(TEST_NAME); 114 } 115 116 @SuppressLint("DefaultLocale") 117 @Override call()118 public Optional<String> call() throws Exception { 119 List<TestModels.TestModelEntry> modelsForAccelerator = 120 AcceleratorSpecificTestSupport.findAllTestModelsRunningOnAccelerator(mContext, 121 mAcceleratorName); 122 123 if (modelsForAccelerator.isEmpty()) { 124 return failure("Cannot find any model to use for testing"); 125 } 126 127 Log.i(TAG, String.format("Checking performance degradation using %d models", 128 modelsForAccelerator.size())); 129 130 TestModels.TestModelEntry modelForInference = modelsForAccelerator.get(0); 131 // The performance degradation is strongly dependent on the model used to compile 132 // so we check all the available ones. 133 for (TestModels.TestModelEntry modelForCompilation : modelsForAccelerator) { 134 Optional<String> currTestResult = testDegradationForModels(modelForInference, 135 modelForCompilation); 136 if (isFailure(currTestResult)) { 137 return currTestResult; 138 } 139 } 140 141 return success(); 142 } 143 144 @SuppressLint("DefaultLocale") testDegradationForModels( TestModels.TestModelEntry inferenceModelEntry, TestModels.TestModelEntry compilationModelEntry)145 public Optional<String> testDegradationForModels( 146 TestModels.TestModelEntry inferenceModelEntry, 147 TestModels.TestModelEntry compilationModelEntry) throws Exception { 148 Log.i(TAG, String.format( 149 "Testing degradation in inference of model %s when running %d threads compliing " 150 + "model %s", 151 inferenceModelEntry.mModelName, mThreadCount, compilationModelEntry.mModelName)); 152 153 Log.d(TAG, String.format("%s: Calculating baseline", mTestName)); 154 // first let's measure a baseline performance 155 final BenchmarkResult baseline = modelPerformanceCollector(inferenceModelEntry, 156 /*start=*/ null).call(); 157 if (baseline.hasBenchmarkError()) { 158 return failure(String.format("%s: Baseline has benchmark error '%s'", 159 mTestName, baseline.getBenchmarkError())); 160 } 161 Log.d(TAG, String.format("%s: Baseline mean time is %f seconds", mTestName, 162 baseline.getMeanTimeSec())); 163 164 Log.d(TAG, String.format("%s: Sleeping for %d millis", mTestName, 165 INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS)); 166 Thread.sleep(INTERVAL_BETWEEN_PERFORMANCE_MEASUREMENTS_MS); 167 168 Log.d(TAG, String.format("%s: Calculating performance with %d threads", mTestName, 169 mThreadCount)); 170 final int totalThreadCount = mThreadCount + 1; 171 final CountDownLatch start = new CountDownLatch(totalThreadCount); 172 ModelCompiler[] compilers = Stream.generate( 173 () -> new ModelCompiler(start, mContext, mAcceleratorName, 174 compilationModelEntry)).limit( 175 mThreadCount).toArray( 176 ModelCompiler[]::new); 177 178 Callable<BenchmarkResult> performanceWithOtherCompilingThreadCollector = 179 modelPerformanceCollector(inferenceModelEntry, start); 180 181 ExecutorService testExecutor = Executors.newFixedThreadPool(totalThreadCount); 182 Future<?>[] compilerFutures = Arrays.stream(compilers).map(testExecutor::submit).toArray( 183 Future[]::new); 184 BenchmarkResult benchmarkWithOtherCompilingThread = testExecutor.submit( 185 performanceWithOtherCompilingThreadCollector).get(); 186 187 Arrays.stream(compilers).forEach(ModelCompiler::stop); 188 Arrays.stream(compilerFutures).forEach(future -> { 189 try { 190 future.get(); 191 } catch (InterruptedException | ExecutionException e) { 192 Log.e(TAG, "Error waiting for compiler process completion", e); 193 } 194 }); 195 196 if (benchmarkWithOtherCompilingThread.hasBenchmarkError()) { 197 return failure( 198 String.format( 199 "%s: Test with parallel compiling thrads has benchmark error '%s'", 200 mTestName, benchmarkWithOtherCompilingThread.getBenchmarkError())); 201 } 202 203 Log.d(TAG, String.format("%s: Multithreaded mean time is %f seconds", 204 mTestName, benchmarkWithOtherCompilingThread.getMeanTimeSec())); 205 206 int performanceDegradation = (int) (((benchmarkWithOtherCompilingThread.getMeanTimeSec() 207 / baseline.getMeanTimeSec()) - 1.0) * 100); 208 209 Log.i(TAG, String.format( 210 "%s: Performance degradation for accelerator %s, with %d threads is %d%%. " 211 + "Threshold " 212 + "is %d%%", 213 mTestName, mAcceleratorName, mThreadCount, performanceDegradation, 214 mMaxPerformanceDegradationPercent)); 215 216 if (performanceDegradation > mMaxPerformanceDegradationPercent) { 217 return failure(String.format("Performance degradation is %d%%. Max acceptable is %d%%", 218 performanceDegradation, mMaxPerformanceDegradationPercent)); 219 } 220 221 return success(); 222 } 223 224 modelPerformanceCollector( final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start)225 private Callable<BenchmarkResult> modelPerformanceCollector( 226 final TestModels.TestModelEntry inferenceModelEntry, final CountDownLatch start) { 227 return () -> { 228 Processor benchmarkProcessor = new Processor(mContext, mNoOpCallback, new int[0]); 229 benchmarkProcessor.setUseNNApi(true); 230 benchmarkProcessor.setNnApiAcceleratorName(mAcceleratorName); 231 if (start != null) { 232 start.countDown(); 233 start.await(); 234 } 235 final BenchmarkResult result = 236 benchmarkProcessor.getInstrumentationResult( 237 inferenceModelEntry, mWarmupTimeSeconds, mRunTimeSeconds); 238 239 return result; 240 }; 241 } 242 243 private static class ModelCompiler implements Callable<Void> { 244 private static final long SLEEP_BETWEEN_COMPILATION_INTERVAL_MS = 20; 245 private final CountDownLatch mStart; 246 private final Processor mProcessor; 247 private final TestModels.TestModelEntry mTestModelEntry; 248 private volatile boolean mRun; 249 250 ModelCompiler(final CountDownLatch start, final Context context, 251 final String acceleratorName, TestModels.TestModelEntry testModelEntry) { 252 mStart = start; 253 mTestModelEntry = testModelEntry; 254 mProcessor = new Processor(context, mNoOpCallback, new int[0]); 255 mProcessor.setUseNNApi(true); 256 mProcessor.setNnApiAcceleratorName(acceleratorName); 257 mProcessor.setRunModelCompilationOnly(true); 258 mRun = true; 259 } 260 261 @Override 262 public Void call() throws IOException, BenchmarkException { 263 if (mStart != null) { 264 try { 265 mStart.countDown(); 266 mStart.await(); 267 } catch (InterruptedException e) { 268 Thread.interrupted(); 269 Log.i(TAG, "Interrupted, stopping processing"); 270 return null; 271 } 272 } 273 while (mRun) { 274 mProcessor.getInstrumentationResult(mTestModelEntry, 0, 0); 275 try { 276 Thread.sleep(SLEEP_BETWEEN_COMPILATION_INTERVAL_MS); 277 } catch (InterruptedException e) { 278 Thread.interrupted(); 279 return null; 280 } 281 } 282 return null; 283 } 284 285 public void stop() { 286 mRun = false; 287 } 288 } 289 } 290