1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.crashtest.core.test; 18 19 import android.annotation.SuppressLint; 20 import android.content.Context; 21 import android.content.Intent; 22 import android.util.Log; 23 import com.android.nn.benchmark.core.InferenceInOut; 24 import com.android.nn.benchmark.core.InferenceInOutSequence; 25 import com.android.nn.benchmark.core.NNTestBase; 26 import com.android.nn.benchmark.core.TestModels; 27 import com.android.nn.crashtest.core.CrashTest; 28 import com.android.nn.crashtest.core.CrashTestCoordinator; 29 import java.io.File; 30 import java.io.FileOutputStream; 31 import java.io.IOException; 32 import java.nio.ByteBuffer; 33 import java.time.Duration; 34 import java.util.Arrays; 35 import java.util.List; 36 import java.util.Optional; 37 38 public class RunModelsInMultipleProcesses implements CrashTest { 39 private static final String TAG = "NN_MPROC_STRESS"; 40 private static final String NATIVE_PROCESS_CMD = "nn_stress_test"; 41 42 public static final String THREADS = "thread_counts"; 43 public static final String PROCESSES = "process_counts"; 44 public static final String TEST_NAME = "test_name"; 45 public static final String MODEL_NAME = "model_name"; 46 public static final String TEST_DURATION = "test_duration"; 47 public static final String NNAPI_DEVICE_NAME = "nnapi_device_name"; 48 public static final String JUST_COMPILE = "just_compile"; 49 public static final String CLIENT_FAILURE_RATE_PERCENT = "client_failure_rate_percent"; 50 public static final long DEFAULT_TEST_DURATION = Duration.ofSeconds(60).toMillis(); 51 public static final int DEFAULT_PROCESSES = 3; 52 public static final int DEFAULT_THREADS = 1; 53 public static final boolean DEFAULT_JUST_COMPILE = false; 54 public static final int DEFAULT_CLIENT_FAILURE_RATE_PERCENT = 0; 55 56 private Context mContext; 57 private int mThreadCount; 58 private int mProcessCount; 59 private String mTestName; 60 private TestModels.TestModelEntry mTestModelEntry; 61 private Duration mTestDuration; 62 private String mNnApiDeviceName; 63 private boolean mJustCompileModel; 64 private int mClientFailureRatePercent; 65 intentInitializer(String testName, String modelName, int processCount, int threadCount, Duration duration, String nnApiDeviceName, boolean justCompileModel, int clientFailureRatePercent)66 static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer(String testName, 67 String modelName, int processCount, int threadCount, Duration duration, 68 String nnApiDeviceName, boolean justCompileModel, int clientFailureRatePercent) { 69 return intent -> { 70 intent.putExtra(TEST_NAME, testName); 71 intent.putExtra(MODEL_NAME, modelName); 72 intent.putExtra(PROCESSES, processCount); 73 intent.putExtra(THREADS, threadCount); 74 intent.putExtra(TEST_DURATION, duration.toMillis()); 75 intent.putExtra(NNAPI_DEVICE_NAME, nnApiDeviceName); 76 intent.putExtra(JUST_COMPILE, justCompileModel); 77 intent.putExtra(CLIENT_FAILURE_RATE_PERCENT, clientFailureRatePercent); 78 }; 79 } intentInitializer( Intent copyFrom)80 static public CrashTestCoordinator.CrashTestIntentInitializer intentInitializer( 81 Intent copyFrom) { 82 return intentInitializer( 83 copyFrom.getStringExtra(RunModelsInMultipleProcesses.TEST_NAME), 84 copyFrom.getStringExtra(RunModelsInMultipleProcesses.MODEL_NAME), 85 copyFrom.getIntExtra(RunModelsInMultipleProcesses.PROCESSES, DEFAULT_PROCESSES), 86 copyFrom.getIntExtra(RunModelsInMultipleProcesses.THREADS, DEFAULT_THREADS), 87 Duration.ofMillis(copyFrom.getLongExtra(TEST_DURATION, DEFAULT_TEST_DURATION)), 88 copyFrom.getStringExtra(RunModelsInMultipleProcesses.NNAPI_DEVICE_NAME), 89 copyFrom.getBooleanExtra(RunModelsInMultipleProcesses.JUST_COMPILE, 90 DEFAULT_JUST_COMPILE), 91 copyFrom.getIntExtra(RunModelsInMultipleProcesses.CLIENT_FAILURE_RATE_PERCENT, 92 DEFAULT_CLIENT_FAILURE_RATE_PERCENT)); 93 } 94 95 @Override init( Context context, Intent configParams, Optional<ProgressListener> progressListener)96 public void init( 97 Context context, Intent configParams, Optional<ProgressListener> progressListener) { 98 mContext = context; 99 mTestName = configParams.getStringExtra(TEST_NAME); 100 mTestModelEntry = TestModels.getModelByName(configParams.getStringExtra(MODEL_NAME)); 101 mProcessCount = configParams.getIntExtra(PROCESSES, DEFAULT_PROCESSES); 102 mThreadCount = configParams.getIntExtra(THREADS, DEFAULT_THREADS); 103 mTestDuration = Duration.ofMillis( 104 configParams.getLongExtra(TEST_DURATION, DEFAULT_TEST_DURATION)); 105 mNnApiDeviceName = configParams.getStringExtra(NNAPI_DEVICE_NAME); 106 mJustCompileModel = configParams.getBooleanExtra(JUST_COMPILE, DEFAULT_JUST_COMPILE); 107 mClientFailureRatePercent = configParams.getIntExtra(CLIENT_FAILURE_RATE_PERCENT, 108 DEFAULT_CLIENT_FAILURE_RATE_PERCENT); 109 } 110 deleteOrWarn(File fileToDelete)111 private void deleteOrWarn(File fileToDelete) { 112 if (fileToDelete.exists()) { 113 if (!fileToDelete.delete()) { 114 Log.w(TAG, String.format("Unable to delete file %s", fileToDelete.getAbsolutePath())); 115 } 116 } 117 } 118 119 @SuppressLint("DefaultLocale") 120 @Override call()121 public Optional<String> call() throws Exception { 122 File targetModelFile = 123 new File(mContext.getExternalFilesDir(null), mTestModelEntry.mModelFile + ".tflite"); 124 File targetInputFile = 125 new File(mContext.getExternalFilesDir(null), mTestModelEntry.mModelFile + ".input"); 126 try { 127 Log.i(TAG, 128 String.format("Trying to create model path '%s'", targetModelFile.getAbsolutePath())); 129 if (!NNTestBase.copyModelToFile( 130 mContext, mTestModelEntry.mModelFile + ".tflite", targetModelFile)) { 131 return failure(String.format("Unable to copy model to target %s file %s", 132 mTestModelEntry.mModelFile, targetModelFile.getAbsolutePath())); 133 } 134 List<InferenceInOutSequence> inputOutputAssets = NNTestBase.getInputOutputAssets( 135 mContext, mTestModelEntry.mInOutAssets, mTestModelEntry.mInOutDatasets); 136 137 if (!mJustCompileModel) { 138 if (!writeModelInput(targetInputFile, inputOutputAssets, mTestModelEntry.mInputShape, 139 mTestModelEntry.mInDataSize)) { 140 return failure(String.format("Cannot write test input data file %s for model %s", 141 targetInputFile.getAbsolutePath(), mTestModelEntry.mModelName)); 142 } 143 } 144 145 String inputShapeAsString = Arrays.toString(mTestModelEntry.mInputShape); 146 inputShapeAsString = inputShapeAsString.substring(1, inputShapeAsString.length() - 1); 147 ProcessBuilder multiProcessTestBuilder = new ProcessBuilder(); 148 multiProcessTestBuilder.command(NATIVE_PROCESS_CMD, targetModelFile.getAbsolutePath(), 149 mJustCompileModel ? "no-file" : targetInputFile.getAbsolutePath(), inputShapeAsString, 150 "" + mTestModelEntry.mInDataSize, "" + mProcessCount, "" + mThreadCount, 151 "" + mTestDuration.getSeconds(), mTestName, ("" + mJustCompileModel).toLowerCase(), 152 mClientFailureRatePercent + "", mNnApiDeviceName != null ? mNnApiDeviceName : ""); 153 154 Process multiProcessTest = multiProcessTestBuilder.start(); 155 156 int testResult = multiProcessTest.waitFor(); 157 Log.i(TAG, String.format("Test process returned %d", testResult)); 158 if (testResult == 0) { 159 return success(); 160 } else { 161 return failure(String.format("Test failed with return code %d", testResult)); 162 } 163 } finally { 164 deleteOrWarn(targetModelFile); 165 deleteOrWarn(targetInputFile); 166 } 167 } 168 writeModelInput(File targetInputFile, List<InferenceInOutSequence> inputOutputAssets, int[] inputShape, int dataByteSize)169 private boolean writeModelInput(File targetInputFile, 170 List<InferenceInOutSequence> inputOutputAssets, int[] inputShape, int dataByteSize) 171 throws IOException { 172 if (!targetInputFile.exists() && !targetInputFile.createNewFile()) { 173 Log.w(TAG, 174 String.format( 175 "Cannot create test input data file %s", targetInputFile.getAbsolutePath())); 176 return false; 177 } 178 179 boolean hasContent = false; 180 try (FileOutputStream inputDataWriter = new FileOutputStream(targetInputFile)) { 181 for (InferenceInOutSequence inferenceInOutSequence : inputOutputAssets) { 182 for (int i = 0; i < inferenceInOutSequence.size(); i++) { 183 byte[] input = inferenceInOutSequence.get(i).mInput; 184 final InferenceInOut.InputCreatorInterface creator = 185 inferenceInOutSequence.get(i).mInputCreator; 186 if (input == null && creator != null) { 187 int byteSize = dataByteSize; 188 for (int dimensionSize : inputShape) { 189 byteSize *= dimensionSize; 190 } 191 input = new byte[byteSize]; 192 ByteBuffer buffer = ByteBuffer.wrap(input); 193 creator.createInput(buffer); 194 } 195 if (input != null) { 196 hasContent = true; 197 inputDataWriter.write(input); 198 } 199 } 200 } 201 } catch (IOException writeException) { 202 Log.w(TAG, String.format("Cannot write to target file %s", targetInputFile.getAbsolutePath()), 203 writeException); 204 return false; 205 } 206 207 if (!hasContent) { 208 Log.w(TAG, 209 String.format("No content in inference input sequence to write for file %s", 210 targetInputFile.getAbsolutePath())); 211 } 212 213 return hasContent; 214 } 215 }