/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <android/log.h>
#include <fcntl.h>
#include <jni.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <time.h>
#include <unistd.h>

#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <functional>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include "run_tflite.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"

#define LOG_TAG "NN_MPROC_STRESS"

constexpr int kInvalidArguments = -1;

enum Arguments : int {
  kArgModelPath = 1,
  kArgInputDataPath,
  kArgInputShape,
  kArgInputElementSize,
  kArgProcessCount,
  kArgThreadCount,
  kArgDurationSeconds,
  kArgTestName,
  kArgJustCompileModel,
  kArgProcessFailureRatePercent,
  kArgNnApiDeviceName,
  kArgMmapModel
};

constexpr int kMandatoryArgsCount = 9;

const char* kUsage =
    R"""(%s modelFileName inputDataFile inputShape inputElementByteSize procCount threadCount durationInSeconds testName justCompileModel [processFailureRate] [nnapiDeviceName] [mmapModel]

where:
  inputShape: comma separated list of integers (e.g. '1,224,224,3')
  justCompileModel: true/false
  processFailureRate: 0 to 100 percent probability of having one of the client processes fail. Defaults to 0.
  mmapModel: true/false, selects whether the TFLite model should be memory mapped from the given file or created from program memory)""";

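// Example invocation (the binary name, paths and values below are
// illustrative only):
//   <stress_test_binary> /data/local/tmp/mobilenet.tflite \
//       /data/local/tmp/inputs.bin 1,224,224,3 4 8 4 60 compile_and_run \
//       false 10 nnapi-reference true
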
bool canReadInputFile(const char* path) {
  std::string modelFileName(path);
  std::ifstream fstream(modelFileName);
  std::stringstream readBuffer;
  readBuffer << fstream.rdbuf();
  return fstream.good();
}

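// Reads |inputDataFileName| into fixed-size buffers (element count of
// |input_shape| times |inputElementSize| bytes each) and appends each buffer
// to |result| as a single-entry InferenceInOutSequence. Returns false if the
// file cannot be opened.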
bool readInputData(const char* inputDataFileName, std::vector<int> input_shape,
                   int inputElementSize,
                   std::vector<InferenceInOutSequence>* result) {
  int inputElementCount = 1;
  std::for_each(
      input_shape.begin(), input_shape.end(),
      [&inputElementCount](int dimSize) { inputElementCount *= dimSize; });
  size_t inputDataSizeBytes = inputElementCount * inputElementSize;

  std::ifstream dataFile;
  dataFile.open(inputDataFileName);
  if (!dataFile) {
    return false;
  }

  std::function<bool(uint8_t*, size_t)> failToGenerateData =
      [](uint8_t*, size_t) { return false; };
  while (!dataFile.eof()) {
    std::unique_ptr<uint8_t[]> dataBuffer =
        std::make_unique<uint8_t[]>(inputDataSizeBytes);
    if (!dataFile.read(reinterpret_cast<char*>(dataBuffer.get()),
                       inputDataSizeBytes)) {
      break;
    }
    InferenceInOut entry{
        dataBuffer.release(), inputDataSizeBytes, {}, failToGenerateData};
    result->push_back({entry});
  }

  return true;
}

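// Exercises the model for |durationSeconds|. When |justCompileModel| is true
// the model is repeatedly compiled (with a short random pause between
// compilations); otherwise a single model instance runs inference on |data|
// for the whole duration, discarding the outputs.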
bool runModel(const char* modelFileName,
              const std::vector<InferenceInOutSequence>& data,
              int durationSeconds, const std::string& nnApiDeviceName,
              bool justCompileModel, bool mmapModel) {
  if (justCompileModel) {
    std::time_t startTime = std::time(nullptr);
    while (std::difftime(std::time(nullptr), startTime) < durationSeconds) {
      int nnapiErrno = 0;
      std::unique_ptr<BenchmarkModel> model(BenchmarkModel::create(
          modelFileName, /*useNnApi=*/true,
          /*enableIntermediateTensorsDump=*/false, &nnapiErrno,
          nnApiDeviceName.empty() ? nullptr : nnApiDeviceName.c_str(),
          mmapModel,
          /*nnapi_cache_dir=*/nullptr));

      if (!model) {
        __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Error creating model");
        return false;
      }

      // Pause between 300 ms and 800 ms before creating the next model.
      constexpr int kMinPauseMs = 300;
      constexpr int kMaxPauseMs = 800;
      int sleepForMs = kMinPauseMs + (drand48() * (kMaxPauseMs - kMinPauseMs));
      usleep(sleepForMs * 1000);
    }

    return true;
  } else {
    int nnapiErrno = 0;
    std::unique_ptr<BenchmarkModel> model(BenchmarkModel::create(
        modelFileName, /*useNnApi=*/true,
        /*enableIntermediateTensorsDump=*/false, &nnapiErrno,
        nnApiDeviceName.empty() ? nullptr : nnApiDeviceName.c_str(), mmapModel,
        /*nnapi_cache_dir=*/nullptr));

    if (!model) {
      __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Error creating model");
      return false;
    }

    std::vector<InferenceResult> result;
    constexpr int flags =
        FLAG_DISCARD_INFERENCE_OUTPUT | FLAG_IGNORE_GOLDEN_OUTPUT;
    return model->benchmark(data, std::numeric_limits<int>::max(),
                            durationSeconds, flags, &result);
  }
}

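// Returns true if argv[argIndex] is "true", false for any other value, or
// |defaultValue| if the argument was not provided.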
bool getBooleanArg(int argc, char* argv[], int argIndex, bool defaultValue) {
  if (argc > argIndex) {
    std::string argAsString(argv[argIndex]);
    return argAsString == "true";
  } else {
    return defaultValue;
  }
}

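// Returns argv[argIndex] parsed as an int, or |defaultValue| if the argument
// was not provided.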
int getIntArg(int argc, char* argv[], int argIndex, int defaultValue) {
  if (argc > argIndex) {
    return std::atoi(argv[argIndex]);
  } else {
    return defaultValue;
  }
}

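// Parses the command line, forks the requested number of client processes and
// runs the model (or just compiles it) on several threads in each of them; the
// parent process waits for all children to exit.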
int main(int argc, char* argv[]) {
  if (argc <= kMandatoryArgsCount) {
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                        "Expected at least %d arguments, got %d",
                        kMandatoryArgsCount, argc - 1);
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, kUsage, argv[0]);
    return kInvalidArguments;
  }

  const char* modelFileName = argv[kArgModelPath];
  const char* inputDataFileName = argv[kArgInputDataPath];
  const char* testName = argv[kArgTestName];
  std::string nnApiDeviceName{
      argc > kArgNnApiDeviceName ? argv[kArgNnApiDeviceName] : ""};
  int numProcesses = getIntArg(argc, argv, kArgProcessCount, 0);
  int numThreads = getIntArg(argc, argv, kArgThreadCount, 0);
  int durationSeconds = getIntArg(argc, argv, kArgDurationSeconds, 0);
  bool justCompileModel =
      getBooleanArg(argc, argv, kArgJustCompileModel, false);
  std::vector<int> inputShape;
  std::istringstream inputShapeStream(argv[kArgInputShape]);
  std::string currSizeToken;
  while (std::getline(inputShapeStream, currSizeToken, ',')) {
    inputShape.push_back(std::stoi(currSizeToken));
  }
  int inputElementSize = getIntArg(argc, argv, kArgInputElementSize, 0);
  int processFailureRate =
      getIntArg(argc, argv, kArgProcessFailureRatePercent, 0);
  bool mmapModel = getBooleanArg(argc, argv, kArgMmapModel, true);

  // Validate params

  if (!canReadInputFile(modelFileName)) {
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                        "Error reading model file '%s'", modelFileName);
    return kInvalidArguments;
  }

  if (numProcesses <= 0 || numThreads <= 0 || durationSeconds <= 0 ||
      inputElementSize <= 0) {
    __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Invalid arguments");
    return kInvalidArguments;
  }

  std::vector<InferenceInOutSequence> inputData;
  if (!justCompileModel) {
    if (!readInputData(inputDataFileName, inputShape, inputElementSize,
                       &inputData)) {
      __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
                          "Error reading input data file '%s'",
                          inputDataFileName);
      return kInvalidArguments;
    }
  }

  __android_log_print(
      ANDROID_LOG_INFO, LOG_TAG,
      "Test '%s': running %s of model at path '%s' with input shape [%s] "
      "(element data size %d), %d processes of %d threads each using device "
      "'%s' for %d seconds",
      testName, justCompileModel ? "compilation only" : "full inference",
      modelFileName, argv[kArgInputShape], inputElementSize, numProcesses,
      numThreads,
      nnApiDeviceName.empty() ? "no-device" : nnApiDeviceName.c_str(),
      durationSeconds);

  srand48(time(NULL) + getpid());

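  // Fork numProcesses children. Only the parent (pid != 0) keeps forking; each
  // child falls through with isSubprocess set and runs the workload below.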
  std::vector<pid_t> children;
  pid_t pid = 1;
  bool forkSucceeded = true;
  bool isSubprocess = false;
  for (int i = 0; i < numProcesses; i++) {
    if (pid != 0) {
      pid = fork();
      if (pid > 0) {
        children.push_back(pid);
        __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "Forked child pid %d",
                            pid);
      } else if (pid < 0) {
        forkSucceeded = false;
        break;
      } else {
        isSubprocess = true;
      }
    }
  }

  if (isSubprocess) {
    // Re-seed in the child: the drand48 state inherited from the parent is
    // identical in every forked process.
    srand48(time(NULL) + getpid());

    __android_log_print(
        ANDROID_LOG_INFO, LOG_TAG,
        "%s model '%s' for %d seconds on device '%s' on %d threads",
        justCompileModel ? "Compiling" : "Running", modelFileName,
        durationSeconds,
        nnApiDeviceName.empty() ? "no-device" : nnApiDeviceName.c_str(),
        numThreads);

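    // With probability processFailureRate percent, have this child kill itself
    // with SIGKILL at a random point within the test duration.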
    bool shouldKillProcess = (drand48() * 100) <= (double)processFailureRate;

    if (shouldKillProcess) {
      float killAfter = durationSeconds * drand48();
      __android_log_print(ANDROID_LOG_INFO, LOG_TAG,
                          "This process will be killed in %f seconds",
                          killAfter);
      std::thread killer = std::thread([killAfter]() {
        usleep(killAfter * 1000.0 * 1000);
        __android_log_print(ANDROID_LOG_INFO, LOG_TAG,
                            "Killing current test process.");
        kill(getpid(), SIGKILL);
      });
      killer.detach();
    }

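    // Run the model concurrently on numThreads threads and wait for all of
    // them to finish.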
    std::vector<std::thread> threads;
    threads.reserve(numThreads);
    for (int i = 0; i < numThreads; i++) {
      threads.push_back(std::thread([&]() {
        runModel(modelFileName, inputData, durationSeconds, nnApiDeviceName,
                 justCompileModel, mmapModel);
      }));
    }
    std::for_each(threads.begin(), threads.end(),
                  [](std::thread& t) { t.join(); });
  } else {
    for (auto pid : children) {
      waitpid(pid, nullptr, 0);
    }
  }

  __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "Test '%s': %s returning",
                      testName, isSubprocess ? "Test process" : "Main process");

  return 0;
}