1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.crashtest.core.test;
18 
19 import static java.util.concurrent.TimeUnit.MILLISECONDS;
20 
21 import android.annotation.SuppressLint;
22 import android.content.Context;
23 import android.content.Intent;
24 import android.util.Log;
25 
26 import com.android.nn.benchmark.core.Processor;
27 import com.android.nn.crashtest.core.CrashTest;
28 import com.android.nn.crashtest.core.CrashTestCoordinator.CrashTestIntentInitializer;
29 
30 import java.time.Duration;
31 import java.util.ArrayList;
32 import java.util.Collections;
33 import java.util.HashSet;
34 import java.util.List;
35 import java.util.Optional;
36 import java.util.Set;
37 import java.util.concurrent.CountDownLatch;
38 import java.util.concurrent.ExecutionException;
39 import java.util.concurrent.ExecutorService;
40 import java.util.concurrent.Executors;
41 import java.util.concurrent.Future;
42 
43 public class RunModelsInParallel implements CrashTest {
44 
45     private static final String MODELS = "models";
46     private static final String DURATION = "duration";
47     private static final String THREADS = "thread_counts";
48     private static final String TEST_NAME = "test_name";
49     private static final String ACCELERATOR_NAME = "accelerator_name";
50     private static final String IGNORE_UNSUPPORTED_MODELS = "ignore_unsupported_models";
51     private static final String RUN_MODEL_COMPILATION_ONLY = "run_model_compilation_only";
52     private static final String MEMORY_MAP_MODEL = "memory_map_model";
53 
54     private final Set<Processor> activeTests = new HashSet<>();
55     private final List<Boolean> mTestCompletionResults = Collections.synchronizedList(
56             new ArrayList<>());
57     private long mTestDurationMillis = 0;
58     private int mThreadCount = 0;
59     private int[] mTestList = new int[0];
60     private String mTestName;
61     private String mAcceleratorName;
62     private boolean mIgnoreUnsupportedModels;
63     private Context mContext;
64     private boolean mRunModelCompilationOnly;
65     private ExecutorService mExecutorService = null;
66     private CountDownLatch mParallelTestComplete;
67     private ProgressListener mProgressListener;
68     private boolean mMmapModel;
69 
intentInitializer(int[] models, int threadCount, Duration duration, String testName, String acceleratorName, boolean ignoreUnsupportedModels, boolean runModelCompilationOnly, boolean mmapModel)70     static public CrashTestIntentInitializer intentInitializer(int[] models, int threadCount,
71             Duration duration, String testName, String acceleratorName,
72             boolean ignoreUnsupportedModels,
73             boolean runModelCompilationOnly, boolean mmapModel) {
74         return intent -> {
75             intent.putExtra(MODELS, models);
76             intent.putExtra(DURATION, duration.toMillis());
77             intent.putExtra(THREADS, threadCount);
78             intent.putExtra(TEST_NAME, testName);
79             intent.putExtra(ACCELERATOR_NAME, acceleratorName);
80             intent.putExtra(IGNORE_UNSUPPORTED_MODELS, ignoreUnsupportedModels);
81             intent.putExtra(RUN_MODEL_COMPILATION_ONLY, runModelCompilationOnly);
82             intent.putExtra(MEMORY_MAP_MODEL, mmapModel);
83         };
84     }
85 
86     @Override
init(Context context, Intent configParams, Optional<ProgressListener> progressListener)87     public void init(Context context, Intent configParams,
88             Optional<ProgressListener> progressListener) {
89         mTestList = configParams.getIntArrayExtra(MODELS);
90         mThreadCount = configParams.getIntExtra(THREADS, 10);
91         mTestDurationMillis = configParams.getLongExtra(DURATION, 1000 * 60 * 10);
92         mTestName = configParams.getStringExtra(TEST_NAME);
93         mAcceleratorName = configParams.getStringExtra(ACCELERATOR_NAME);
94         mIgnoreUnsupportedModels = mAcceleratorName != null && configParams.getBooleanExtra(
95                 IGNORE_UNSUPPORTED_MODELS, false);
96         mRunModelCompilationOnly = configParams.getBooleanExtra(RUN_MODEL_COMPILATION_ONLY, false);
97         mMmapModel = configParams.getBooleanExtra(MEMORY_MAP_MODEL, false);
98         mContext = context;
99         mProgressListener = progressListener.orElseGet(() -> (Optional<String> message) -> {
100             Log.v(CrashTest.TAG, message.orElse("."));
101         });
102         mExecutorService = Executors.newFixedThreadPool(mThreadCount);
103         mTestCompletionResults.clear();
104     }
105 
106     @Override
call()107     public Optional<String> call() {
108         mParallelTestComplete = new CountDownLatch(mThreadCount);
109         for (int i = 0; i < mThreadCount; i++) {
110             Processor testProcessor = createSubTestRunner(mTestList, i);
111 
112             activeTests.add(testProcessor);
113             mExecutorService.submit(testProcessor);
114         }
115 
116         return completedSuccessfully();
117     }
118 
createSubTestRunner(final int[] testList, final int testIndex)119     private Processor createSubTestRunner(final int[] testList, final int testIndex) {
120         final Processor result = new Processor(mContext, new Processor.Callback() {
121             @SuppressLint("DefaultLocale")
122             @Override
123             public void onBenchmarkFinish(boolean ok) {
124                 notifyProgress("Test '%s': Benchmark #%d completed %s", mTestName, testIndex,
125                         ok ? "successfully" : "with failure");
126                 mTestCompletionResults.add(ok);
127                 mParallelTestComplete.countDown();
128             }
129 
130             @Override
131             public void onStatusUpdate(int testNumber, int numTests, String modelName) {
132             }
133         }, testList);
134         result.setUseNNApi(true);
135         result.setCompleteInputSet(false);
136         result.setNnApiAcceleratorName(mAcceleratorName);
137         result.setIgnoreUnsupportedModels(mIgnoreUnsupportedModels);
138         result.setRunModelCompilationOnly(mRunModelCompilationOnly);
139         result.setMmapModel(mMmapModel);
140         return result;
141     }
142 
endTests()143     private void endTests() {
144         ExecutorService terminatorsThreadPool = Executors.newFixedThreadPool(activeTests.size());
145         List<Future<?>> terminationCommands = new ArrayList<>();
146         for (final Processor test : activeTests) {
147             // Exit will block until the thread is completed
148             terminationCommands.add(terminatorsThreadPool.submit(
149                     () -> test.exitWithTimeout(Duration.ofSeconds(20).toMillis())));
150         }
151         terminationCommands.forEach(terminationCommand -> {
152             try {
153                 terminationCommand.get();
154             } catch (ExecutionException e) {
155                 Log.w(TAG, "Failure while waiting for completion of tests", e);
156             } catch (InterruptedException e) {
157                 Thread.interrupted();
158             }
159         });
160     }
161 
162     @SuppressLint("DefaultLocale")
notifyProgress(String messageFormat, Object... args)163     void notifyProgress(String messageFormat, Object... args) {
164         mProgressListener.testProgress(Optional.of(String.format(messageFormat, args)));
165     }
166 
167     // This method blocks until the tests complete and returns true if all tests completed
168     // successfully
169     @SuppressLint("DefaultLocale")
completedSuccessfully()170     private Optional<String> completedSuccessfully() {
171         try {
172             boolean testsEnded = mParallelTestComplete.await(mTestDurationMillis, MILLISECONDS);
173             if (!testsEnded) {
174                 Log.i(TAG,
175                         String.format(
176                                 "Test '%s': Tests are not completed (they might have been "
177                                         + "designed to run "
178                                         + "indefinitely. Forcing termination.", mTestName));
179                 endTests();
180             }
181         } catch (InterruptedException ignored) {
182             Thread.currentThread().interrupt();
183         }
184 
185         final long failedTestCount = mTestCompletionResults.stream().filter(
186                 testResult -> !testResult).count();
187         if (failedTestCount > 0) {
188             String failureMsg = String.format("Test '%s': %d out of %d test failed", mTestName,
189                     failedTestCount,
190                     mTestCompletionResults.size());
191             Log.w(CrashTest.TAG, failureMsg);
192             return failure(failureMsg);
193         } else {
194             Log.i(CrashTest.TAG,
195                     String.format("Test '%s': Test completed successfully", mTestName));
196             return success();
197         }
198     }
199 }
200