1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.crashtest.core.test;
18 
import android.annotation.SuppressLint;
import android.content.Context;
import android.content.Intent;
import android.util.Log;
import com.android.nn.benchmark.core.InferenceInOut;
import com.android.nn.benchmark.core.InferenceInOutSequence;
import com.android.nn.benchmark.core.NNTestBase;
import com.android.nn.benchmark.core.TestModels;
import com.android.nn.crashtest.core.CrashTest;
import com.android.nn.crashtest.core.CrashTestCoordinator;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
37 
38 public class RunModelsInMultipleProcesses implements CrashTest {
39   private static final String TAG = "NN_MPROC_STRESS";
40   private static final String NATIVE_PROCESS_CMD = "nn_stress_test";
41 
42   public static final String THREADS = "thread_counts";
43   public static final String PROCESSES = "process_counts";
44   public static final String TEST_NAME = "test_name";
45   public static final String MODEL_NAME = "model_name";
46   public static final String TEST_DURATION = "test_duration";
47   public static final String NNAPI_DEVICE_NAME = "nnapi_device_name";
48   public static final String JUST_COMPILE = "just_compile";
49   public static final String CLIENT_FAILURE_RATE_PERCENT = "client_failure_rate_percent";
50   public static final long DEFAULT_TEST_DURATION = Duration.ofSeconds(60).toMillis();
51   public static final int DEFAULT_PROCESSES = 3;
52   public static final int DEFAULT_THREADS = 1;
53   public static final boolean DEFAULT_JUST_COMPILE = false;
54   public static final int DEFAULT_CLIENT_FAILURE_RATE_PERCENT = 0;
55 
56   private Context mContext;
57   private int mThreadCount;
58   private int mProcessCount;
59   private String mTestName;
60   private TestModels.TestModelEntry mTestModelEntry;
61   private Duration mTestDuration;
62   private String mNnApiDeviceName;
63   private boolean mJustCompileModel;
64   private int mClientFailureRatePercent;
65 
intentInitializer(String testName, String modelName, int processCount, int threadCount, Duration duration, String nnApiDeviceName, boolean justCompileModel, int clientFailureRatePercent)66   static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(String testName,
67       String modelName, int processCount, int threadCount, Duration duration,
68       String nnApiDeviceName, boolean justCompileModel, int clientFailureRatePercent) {
69     return intent -> {
70       intent.putExtra(TEST_NAME, testName);
71       intent.putExtra(MODEL_NAME, modelName);
72       intent.putExtra(PROCESSES, processCount);
73       intent.putExtra(THREADS, threadCount);
74       intent.putExtra(TEST_DURATION, duration.toMillis());
75       intent.putExtra(NNAPI_DEVICE_NAME, nnApiDeviceName);
76       intent.putExtra(JUST_COMPILE, justCompileModel);
77       intent.putExtra(CLIENT_FAILURE_RATE_PERCENT, clientFailureRatePercent);
78     };
79   }
intentInitializer( Intent copyFrom)80   static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(
81             Intent copyFrom) {
82     return intentInitializer(
83             copyFrom.getStringExtra(RunModelsInMultipleProcesses.TEST_NAME),
84             copyFrom.getStringExtra(RunModelsInMultipleProcesses.MODEL_NAME),
85             copyFrom.getIntExtra(RunModelsInMultipleProcesses.PROCESSES, DEFAULT_PROCESSES),
86             copyFrom.getIntExtra(RunModelsInMultipleProcesses.THREADS, DEFAULT_THREADS),
87             Duration.ofMillis(copyFrom.getLongExtra(TEST_DURATION, DEFAULT_TEST_DURATION)),
88             copyFrom.getStringExtra(RunModelsInMultipleProcesses.NNAPI_DEVICE_NAME),
89             copyFrom.getBooleanExtra(RunModelsInMultipleProcesses.JUST_COMPILE,
90                     DEFAULT_JUST_COMPILE),
91             copyFrom.getIntExtra(RunModelsInMultipleProcesses.CLIENT_FAILURE_RATE_PERCENT,
92                     DEFAULT_CLIENT_FAILURE_RATE_PERCENT));
93   }
94 
95   @Override
init( Context context, Intent configParams, Optional<ProgressListener> progressListener)96   public void init(
97       Context context, Intent configParams, Optional<ProgressListener> progressListener) {
98     mContext = context;
99     mTestName = configParams.getStringExtra(TEST_NAME);
100     mTestModelEntry = TestModels.getModelByName(configParams.getStringExtra(MODEL_NAME));
101     mProcessCount = configParams.getIntExtra(PROCESSES, DEFAULT_PROCESSES);
102     mThreadCount = configParams.getIntExtra(THREADS, DEFAULT_THREADS);
103     mTestDuration = Duration.ofMillis(
104         configParams.getLongExtra(TEST_DURATION, DEFAULT_TEST_DURATION));
105     mNnApiDeviceName = configParams.getStringExtra(NNAPI_DEVICE_NAME);
106     mJustCompileModel = configParams.getBooleanExtra(JUST_COMPILE, DEFAULT_JUST_COMPILE);
107     mClientFailureRatePercent = configParams.getIntExtra(CLIENT_FAILURE_RATE_PERCENT,
108             DEFAULT_CLIENT_FAILURE_RATE_PERCENT);
109   }
110 
deleteOrWarn(File fileToDelete)111   private void deleteOrWarn(File fileToDelete) {
112     if (fileToDelete.exists()) {
113       if (!fileToDelete.delete()) {
114         Log.w(TAG, String.format("Unable to delete file %s", fileToDelete.getAbsolutePath()));
115       }
116     }
117   }
118 
119   @SuppressLint("DefaultLocale")
120   @Override
call()121   public Optional<String> call() throws Exception {
122     File targetModelFile =
123         new File(mContext.getExternalFilesDir(null), mTestModelEntry.mModelFile + ".tflite");
124     File targetInputFile =
125         new File(mContext.getExternalFilesDir(null), mTestModelEntry.mModelFile + ".input");
126     try {
127       Log.i(TAG,
128           String.format("Trying to create model path '%s'", targetModelFile.getAbsolutePath()));
129       if (!NNTestBase.copyModelToFile(
130               mContext, mTestModelEntry.mModelFile + ".tflite", targetModelFile)) {
131         return failure(String.format("Unable to copy model to target %s file %s",
132             mTestModelEntry.mModelFile, targetModelFile.getAbsolutePath()));
133       }
134       List<InferenceInOutSequence> inputOutputAssets = NNTestBase.getInputOutputAssets(
135           mContext, mTestModelEntry.mInOutAssets, mTestModelEntry.mInOutDatasets);
136 
137       if (!mJustCompileModel) {
138         if (!writeModelInput(targetInputFile, inputOutputAssets, mTestModelEntry.mInputShape,
139                 mTestModelEntry.mInDataSize)) {
140           return failure(String.format("Cannot write test input data file %s for model %s",
141               targetInputFile.getAbsolutePath(), mTestModelEntry.mModelName));
142         }
143       }
144 
145       String inputShapeAsString = Arrays.toString(mTestModelEntry.mInputShape);
146       inputShapeAsString = inputShapeAsString.substring(1, inputShapeAsString.length() - 1);
147       ProcessBuilder multiProcessTestBuilder = new ProcessBuilder();
148       multiProcessTestBuilder.command(NATIVE_PROCESS_CMD, targetModelFile.getAbsolutePath(),
149           mJustCompileModel ? "no-file" : targetInputFile.getAbsolutePath(), inputShapeAsString,
150           "" + mTestModelEntry.mInDataSize, "" + mProcessCount, "" + mThreadCount,
151           "" + mTestDuration.getSeconds(), mTestName, ("" + mJustCompileModel).toLowerCase(),
152           mClientFailureRatePercent + "", mNnApiDeviceName != null ? mNnApiDeviceName : "");
153 
154       Process multiProcessTest = multiProcessTestBuilder.start();
155 
156       int testResult = multiProcessTest.waitFor();
157       Log.i(TAG, String.format("Test process returned %d", testResult));
158       if (testResult == 0) {
159         return success();
160       } else {
161         return failure(String.format("Test failed with return code %d", testResult));
162       }
163     } finally {
164       deleteOrWarn(targetModelFile);
165       deleteOrWarn(targetInputFile);
166     }
167   }
168 
writeModelInput(File targetInputFile, List<InferenceInOutSequence> inputOutputAssets, int[] inputShape, int dataByteSize)169   private boolean writeModelInput(File targetInputFile,
170       List<InferenceInOutSequence> inputOutputAssets, int[] inputShape, int dataByteSize)
171       throws IOException {
172     if (!targetInputFile.exists() && !targetInputFile.createNewFile()) {
173       Log.w(TAG,
174           String.format(
175               "Cannot create test input data file %s", targetInputFile.getAbsolutePath()));
176       return false;
177     }
178 
179     boolean hasContent = false;
180     try (FileOutputStream inputDataWriter = new FileOutputStream(targetInputFile)) {
181       for (InferenceInOutSequence inferenceInOutSequence : inputOutputAssets) {
182         for (int i = 0; i < inferenceInOutSequence.size(); i++) {
183           byte[] input = inferenceInOutSequence.get(i).mInput;
184           final InferenceInOut.InputCreatorInterface creator =
185               inferenceInOutSequence.get(i).mInputCreator;
186           if (input == null && creator != null) {
187             int byteSize = dataByteSize;
188             for (int dimensionSize : inputShape) {
189               byteSize *= dimensionSize;
190             }
191             input = new byte[byteSize];
192             ByteBuffer buffer = ByteBuffer.wrap(input);
193             creator.createInput(buffer);
194           }
195           if (input != null) {
196             hasContent = true;
197             inputDataWriter.write(input);
198           }
199         }
200       }
201     } catch (IOException writeException) {
202       Log.w(TAG, String.format("Cannot write to target file %s", targetInputFile.getAbsolutePath()),
203           writeException);
204       return false;
205     }
206 
207     if (!hasContent) {
208       Log.w(TAG,
209           String.format("No content in inference input sequence to write for file %s",
210               targetInputFile.getAbsolutePath()));
211     }
212 
213     return hasContent;
214   }
215 }