1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.crashtest.core.test; 18 19 import static java.util.concurrent.TimeUnit.MILLISECONDS; 20 21 import android.annotation.SuppressLint; 22 import android.content.Context; 23 import android.content.Intent; 24 import android.util.Log; 25 26 import com.android.nn.benchmark.core.Processor; 27 import com.android.nn.crashtest.core.CrashTest; 28 import com.android.nn.crashtest.core.CrashTestCoordinator.CrashTestIntentInitializer; 29 30 import java.time.Duration; 31 import java.util.ArrayList; 32 import java.util.Collections; 33 import java.util.HashSet; 34 import java.util.List; 35 import java.util.Optional; 36 import java.util.Set; 37 import java.util.concurrent.CountDownLatch; 38 import java.util.concurrent.ExecutionException; 39 import java.util.concurrent.ExecutorService; 40 import java.util.concurrent.Executors; 41 import java.util.concurrent.Future; 42 43 public class RunModelsInParallel implements CrashTest { 44 45 private static final String MODELS = "models"; 46 private static final String DURATION = "duration"; 47 private static final String THREADS = "thread_counts"; 48 private static final String TEST_NAME = "test_name"; 49 private static final String ACCELERATOR_NAME = "accelerator_name"; 50 private static final String IGNORE_UNSUPPORTED_MODELS = "ignore_unsupported_models"; 51 private static final String RUN_MODEL_COMPILATION_ONLY = "run_model_compilation_only"; 52 private static final String MEMORY_MAP_MODEL = "memory_map_model"; 53 54 private final Set<Processor> activeTests = new HashSet<>(); 55 private final List<Boolean> mTestCompletionResults = Collections.synchronizedList( 56 new ArrayList<>()); 57 private long mTestDurationMillis = 0; 58 private int mThreadCount = 0; 59 private int[] mTestList = new int[0]; 60 private String mTestName; 61 private String mAcceleratorName; 62 private boolean mIgnoreUnsupportedModels; 63 private Context mContext; 64 private boolean mRunModelCompilationOnly; 65 private ExecutorService mExecutorService = null; 66 private CountDownLatch mParallelTestComplete; 67 private ProgressListener mProgressListener; 68 private boolean mMmapModel; 69 intentInitializer(int[] models, int threadCount, Duration duration, String testName, String acceleratorName, boolean ignoreUnsupportedModels, boolean runModelCompilationOnly, boolean mmapModel)70 static public CrashTestIntentInitializer intentInitializer(int[] models, int threadCount, 71 Duration duration, String testName, String acceleratorName, 72 boolean ignoreUnsupportedModels, 73 boolean runModelCompilationOnly, boolean mmapModel) { 74 return intent -> { 75 intent.putExtra(MODELS, models); 76 intent.putExtra(DURATION, duration.toMillis()); 77 intent.putExtra(THREADS, threadCount); 78 intent.putExtra(TEST_NAME, testName); 79 intent.putExtra(ACCELERATOR_NAME, acceleratorName); 80 intent.putExtra(IGNORE_UNSUPPORTED_MODELS, ignoreUnsupportedModels); 81 intent.putExtra(RUN_MODEL_COMPILATION_ONLY, runModelCompilationOnly); 82 intent.putExtra(MEMORY_MAP_MODEL, mmapModel); 83 }; 84 } 85 86 @Override init(Context context, Intent configParams, Optional<ProgressListener> progressListener)87 public void init(Context context, Intent configParams, 88 Optional<ProgressListener> progressListener) { 89 mTestList = configParams.getIntArrayExtra(MODELS); 90 mThreadCount = configParams.getIntExtra(THREADS, 10); 91 mTestDurationMillis = configParams.getLongExtra(DURATION, 1000 * 60 * 10); 92 mTestName = configParams.getStringExtra(TEST_NAME); 93 mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME); 94 mIgnoreUnsupportedModels = mAcceleratorName != null && configParams.getBooleanExtra( 95 IGNORE_UNSUPPORTED_MODELS, false); 96 mRunModelCompilationOnly = configParams.getBooleanExtra(RUN_MODEL_COMPILATION_ONLY, false); 97 mMmapModel = configParams.getBooleanExtra(MEMORY_MAP_MODEL, false); 98 mContext = context; 99 mProgressListener = progressListener.orElseGet(() -> (Optional<String> message) -> { 100 Log.v(CrashTest.TAG, message.orElse(".")); 101 }); 102 mExecutorService = Executors.newFixedThreadPool(mThreadCount); 103 mTestCompletionResults.clear(); 104 } 105 106 @Override call()107 public Optional<String> call() { 108 mParallelTestComplete = new CountDownLatch(mThreadCount); 109 for (int i = 0; i < mThreadCount; i++) { 110 Processor testProcessor = createSubTestRunner(mTestList, i); 111 112 activeTests.add(testProcessor); 113 mExecutorService.submit(testProcessor); 114 } 115 116 return completedSuccessfully(); 117 } 118 createSubTestRunner(final int[] testList, final int testIndex)119 private Processor createSubTestRunner(final int[] testList, final int testIndex) { 120 final Processor result = new Processor(mContext, new Processor.Callback() { 121 @SuppressLint("DefaultLocale") 122 @Override 123 public void onBenchmarkFinish(boolean ok) { 124 notifyProgress("Test '%s': Benchmark #%d completed %s", mTestName, testIndex, 125 ok ? "successfully" : "with failure"); 126 mTestCompletionResults.add(ok); 127 mParallelTestComplete.countDown(); 128 } 129 130 @Override 131 public void onStatusUpdate(int testNumber, int numTests, String modelName) { 132 } 133 }, testList); 134 result.setUseNNApi(true); 135 result.setCompleteInputSet(false); 136 result.setNnApiAcceleratorName(mAcceleratorName); 137 result.setIgnoreUnsupportedModels(mIgnoreUnsupportedModels); 138 result.setRunModelCompilationOnly(mRunModelCompilationOnly); 139 result.setMmapModel(mMmapModel); 140 return result; 141 } 142 endTests()143 private void endTests() { 144 ExecutorService terminatorsThreadPool = Executors.newFixedThreadPool(activeTests.size()); 145 List<Future<?>> terminationCommands = new ArrayList<>(); 146 for (final Processor test : activeTests) { 147 // Exit will block until the thread is completed 148 terminationCommands.add(terminatorsThreadPool.submit( 149 () -> test.exitWithTimeout(Duration.ofSeconds(20).toMillis()))); 150 } 151 terminationCommands.forEach(terminationCommand -> { 152 try { 153 terminationCommand.get(); 154 } catch (ExecutionException e) { 155 Log.w(TAG, "Failure while waiting for completion of tests", e); 156 } catch (InterruptedException e) { 157 Thread.interrupted(); 158 } 159 }); 160 } 161 162 @SuppressLint("DefaultLocale") notifyProgress(String messageFormat, Object... args)163 void notifyProgress(String messageFormat, Object... args) { 164 mProgressListener.testProgress(Optional.of(String.format(messageFormat, args))); 165 } 166 167 // This method blocks until the tests complete and returns true if all tests completed 168 // successfully 169 @SuppressLint("DefaultLocale") completedSuccessfully()170 private Optional<String> completedSuccessfully() { 171 try { 172 boolean testsEnded = mParallelTestComplete.await(mTestDurationMillis, MILLISECONDS); 173 if (!testsEnded) { 174 Log.i(TAG, 175 String.format( 176 "Test '%s': Tests are not completed (they might have been " 177 + "designed to run " 178 + "indefinitely. Forcing termination.", mTestName)); 179 endTests(); 180 } 181 } catch (InterruptedException ignored) { 182 Thread.currentThread().interrupt(); 183 } 184 185 final long failedTestCount = mTestCompletionResults.stream().filter( 186 testResult -> !testResult).count(); 187 if (failedTestCount > 0) { 188 String failureMsg = String.format("Test '%s': %d out of %d test failed", mTestName, 189 failedTestCount, 190 mTestCompletionResults.size()); 191 Log.w(CrashTest.TAG, failureMsg); 192 return failure(failureMsg); 193 } else { 194 Log.i(CrashTest.TAG, 195 String.format("Test '%s': Test completed successfully", mTestName)); 196 return success(); 197 } 198 } 199 } 200