1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <android/log.h>
#include <fcntl.h>
#include <jni.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <time.h>
#include <unistd.h>

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include "run_tflite.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
36 
#define LOG_TAG "NN_MPROC_STRESS"

// Exit code returned when command-line arguments are missing or invalid.
constexpr int kInvalidArguments = -1;

// Positional argv indexes for each command-line argument. Starts at 1
// because argv[0] is the program name. Entries after kArgJustCompileModel
// are optional and fall back to defaults when absent (see main()).
enum Arguments : int {
  kArgModelPath = 1,
  kArgInputDataPath,
  kArgInputShape,
  kArgInputElementSize,
  kArgProcessCount,
  kArgThreadCount,
  kArgDurationSeconds,
  kArgTestName,
  kArgJustCompileModel,
  kArgProcessFailureRatePercent,
  kArgNnApiDeviceName,
  kArgMmapModel
};

// Number of mandatory arguments (kArgModelPath through kArgJustCompileModel).
constexpr int kMandatoryArgsCount = 9;

// printf-style usage message. NOTE(review): it contains exactly one %s
// placeholder, so it expects a single format argument (the program name).
const char* kUsage =
    R"""(%s modelFileName inputDataFile inputShape inputElementByteSize procCount threadCount durationInSeconds testName justCompileModel [processFailureRate] [nnapiDeviceName] [mmapModel]

                          where:
                              inputShape comma separated list of integers (e.g. '1,224,224,3')
                              justCompileModel: true/false)
                              processFailureRate: 0 to 100 percent probability of having one of the client processes failing. Defaults to 0.)
                              mmapModel: true/false select if the TFLite model should be memory mapped to the given file or created from program memory)""";
66 
// Returns true if the file at |path| exists and at least one byte can be
// read from it; used to validate the model file before forking workers.
// Returns false for missing, unreadable, or empty files.
bool canReadInputFile(const char* path) {
  std::ifstream fstream(path);
  if (!fstream) {
    return false;
  }
  // Probe a single character instead of copying the whole (potentially
  // large) model file into memory via rdbuf(). peek() sets eofbit on an
  // empty file, so good() is false then — same result as the old full-read
  // check, without the allocation and I/O.
  fstream.peek();
  return fstream.good();
}
74 
readInputData(const char * inputDataFileName,std::vector<int> input_shape,int inputElementSize,std::vector<InferenceInOutSequence> * result)75 bool readInputData(const char* inputDataFileName, std::vector<int> input_shape,
76                    int inputElementSize,
77                    std::vector<InferenceInOutSequence>* result) {
78   int inputElementCount = 1;
79   std::for_each(
80       input_shape.begin(), input_shape.end(),
81       [&inputElementCount](int dimSize) { inputElementCount *= dimSize; });
82   size_t inputDataSizeBytes = inputElementCount * inputElementSize;
83 
84   std::ifstream dataFile;
85   dataFile.open(inputDataFileName);
86   if (!dataFile) {
87     return false;
88   }
89 
90   std::function<bool(uint8_t*, size_t)> failToGenerateData =
91       [](uint8_t*, size_t) { return false; };
92   while (!dataFile.eof()) {
93     std::unique_ptr<uint8_t[]> dataBuffer =
94         std::make_unique<uint8_t[]>(inputDataSizeBytes);
95     if (!dataFile.read(reinterpret_cast<char*>(dataBuffer.get()),
96                        inputDataSizeBytes)) {
97       break;
98     }
99     InferenceInOut entry{
100         dataBuffer.release(), inputDataSizeBytes, {}, failToGenerateData};
101     result->push_back({entry});
102   }
103 
104   return result;
105 }
106 
runModel(const char * modelFileName,const std::vector<InferenceInOutSequence> & data,int durationSeconds,const std::string & nnApiDeviceName,bool justCompileModel,bool mmapModel)107 bool runModel(const char* modelFileName,
108               const std::vector<InferenceInOutSequence>& data,
109               int durationSeconds, const std::string& nnApiDeviceName,
110               bool justCompileModel, bool mmapModel) {
111   if (justCompileModel) {
112     std::time_t startTime = std::time(nullptr);
113     while (std::difftime(std::time(nullptr), startTime) < durationSeconds) {
114       int nnapiErrno = 0;
115       std::unique_ptr<BenchmarkModel> model(BenchmarkModel::create(
116           modelFileName, /*useNnApi=*/true,
117           /*enableIntermediateTensorsDump=*/false,
118           &nnapiErrno,
119           nnApiDeviceName.empty() ? nullptr : nnApiDeviceName.c_str(), mmapModel,
120           /*nnapi_cache_dir=*/nullptr));
121 
122       if (!model) {
123         __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Error creating model");
124         return false;
125       }
126 
127       // sleeping from 300ms to 800ms
128       constexpr int kMinPauseMs = 300;
129       constexpr int kMaxPauseMs = 800;
130       int sleepForMs = kMinPauseMs + (drand48() * (kMaxPauseMs - kMinPauseMs));
131       usleep(sleepForMs * 1000);
132     }
133 
134     return true;
135   } else {
136     int nnapiErrno = 0;
137     std::unique_ptr<BenchmarkModel> model(BenchmarkModel::create(
138         modelFileName, /*useNnApi=*/true,
139         /*enableIntermediateTensorsDump=*/false,
140         &nnapiErrno,
141         nnApiDeviceName.empty() ? nullptr : nnApiDeviceName.c_str(), mmapModel,
142         /*nnapi_cache_dir=*/nullptr));
143 
144     if (!model) {
145       __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Error creating model");
146       return false;
147     }
148 
149     std::vector<InferenceResult> result;
150     constexpr int flags =
151         FLAG_DISCARD_INFERENCE_OUTPUT | FLAG_IGNORE_GOLDEN_OUTPUT;
152     return model->benchmark(data, std::numeric_limits<int>::max(),
153                             durationSeconds, flags, &result);
154   }
155 }
156 
// Returns true iff argv[argIndex] is present and equals the literal "true";
// any other present value yields false. Falls back to |defaultValue| when
// the argument was not supplied.
bool getBooleanArg(int argc, char* argv[], int argIndex, bool defaultValue) {
    if (argIndex >= argc) {
        return defaultValue;
    }
    return std::string(argv[argIndex]) == "true";
}
165 
// Returns argv[argIndex] parsed as an int (atoi semantics: 0 on a
// non-numeric value), or |defaultValue| when the argument was not supplied.
int getIntArg(int argc, char* argv[], int argIndex, int defaultValue) {
    return (argIndex < argc) ? std::atoi(argv[argIndex]) : defaultValue;
}
173 
main(int argc,char * argv[])174 int main(int argc, char* argv[]) {
175   if (argc < kMandatoryArgsCount) {
176     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, kUsage, kMandatoryArgsCount,
177                         argc, argv[0]);
178     return kInvalidArguments;
179   }
180 
181   const char* modelFileName = argv[kArgModelPath];
182   const char* inputDataFileName = argv[kArgInputDataPath];
183   const char* testName = argv[kArgTestName];
184   std::string nnApiDeviceName{
185       argc > kArgNnApiDeviceName ? argv[kArgNnApiDeviceName] : ""};
186   int numProcesses = getIntArg(argc, argv, kArgProcessCount, 0);
187   int numThreads = getIntArg(argc, argv, kArgThreadCount, 0);
188   int durationSeconds = getIntArg(argc, argv, kArgDurationSeconds, 0);
189   bool justCompileModel =
190       getBooleanArg(argc, argv, kArgJustCompileModel, false);
191   std::vector<int> inputShape;
192   std::istringstream inputShapeStream(argv[kArgInputShape]);
193   std::string currSizeToken;
194   while (std::getline(inputShapeStream, currSizeToken, ',')) {
195     inputShape.push_back(std::stoi(currSizeToken));
196   }
197   int inputElementSize = getIntArg(argc, argv, kArgInputElementSize, 0);
198   int processFailureRate = getIntArg(argc, argv, kArgProcessFailureRatePercent, 0);
199 
200 
201   bool mmapModel = getBooleanArg(argc, argv, kArgMmapModel, true);
202 
203   // Validate params
204 
205   if (!canReadInputFile(modelFileName)) {
206     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
207                         "Error reading model file '%s'", modelFileName);
208     return kInvalidArguments;
209   }
210 
211   std::vector<InferenceInOutSequence> inputData;
212   if (!justCompileModel) {
213     if (!readInputData(inputDataFileName, inputShape, inputElementSize,
214                        &inputData)) {
215       __android_log_print(ANDROID_LOG_ERROR, LOG_TAG,
216                           "Error reading input data file '%s'",
217                           inputDataFileName);
218       return kInvalidArguments;
219     }
220   }
221 
222   if (numProcesses <= 0 || numThreads <= 0 || durationSeconds <= 0 ||
223       inputElementSize <= 0) {
224     __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Invalid arguments");
225     return kInvalidArguments;
226   }
227 
228   __android_log_print(
229       ANDROID_LOG_INFO, LOG_TAG,
230       "Test '%s': running %s of model at path '%s' with input shape [%s] "
231       "(element data size %d),"
232       " %d processes of %d threads each using device '%s' for %d seconds",
233       testName, justCompileModel ? "compilation only" : "full inference",
234       modelFileName, argv[kArgInputShape], inputElementSize, numProcesses,
235       numThreads,
236       nnApiDeviceName.empty() ? "no-device" : nnApiDeviceName.c_str(),
237       durationSeconds);
238 
239   srand48(time(NULL) + getpid());
240 
241   std::vector<pid_t> children;
242   pid_t pid = 1;
243   bool forkSucceeded = true;
244   bool isSubprocess = false;
245   for (int i = 0; i < numProcesses; i++) {
246     if (pid != 0) {
247       pid = fork();
248       if (pid > 0) {
249         children.push_back(pid);
250         __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "Forked child pid %d",
251                             pid);
252       } else if (pid < 0) {
253         forkSucceeded = false;
254         break;
255       } else {
256         isSubprocess = true;
257       }
258     }
259   }
260 
261   if (isSubprocess) {
262     __android_log_print(
263         ANDROID_LOG_INFO, LOG_TAG,
264         "%s model '%s': for %d seconds on device '%s' on %d threads",
265         justCompileModel ? "Compiling" : "Running", modelFileName,
266         durationSeconds,
267         nnApiDeviceName.empty() ? "no-device" : nnApiDeviceName.c_str(),
268         numThreads);
269 
270     bool shouldKillProcess = (drand48() * 100) <= (double)processFailureRate;
271 
272     if (shouldKillProcess) {
273       float killAfter = durationSeconds * drand48();
274       __android_log_print(ANDROID_LOG_INFO, LOG_TAG,
275                           "This process will be killed in %f seconds",
276                           killAfter);
277       std::thread killer = std::thread([killAfter]() {
278         usleep(killAfter * 1000.0 * 1000);
279         __android_log_print(ANDROID_LOG_INFO, LOG_TAG,
280                             "Killing current test process.");
281         kill(getpid(), 9);
282       });
283       killer.detach();
284     }
285 
286     std::vector<std::thread> threads;
287     threads.reserve(numThreads);
288     for (int i = 0; i < numThreads; i++) {
289       threads.push_back(std::thread([&]() {
290         runModel(modelFileName, inputData, durationSeconds, nnApiDeviceName,
291                  justCompileModel, mmapModel);
292       }));
293     }
294     std::for_each(threads.begin(), threads.end(),
295                   [](std::thread& t) { t.join(); });
296   } else {
297     for (auto pid : children) {
298       waitpid(pid, nullptr, 0);
299     }
300   }
301 
302   __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "Test '%s': %s returning ",
303                       testName, isSubprocess ? "Test process" : "Main process");
304 
305   return 0;
306 }
307