1 /**
2 * Copyright 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "run_tflite.h"
18
19 #include "tensorflow/lite/nnapi/nnapi_implementation.h"
20
21 #include <jni.h>
22 #include <string>
23 #include <iomanip>
24 #include <sstream>
25 #include <fcntl.h>
26
27 #include <android/asset_manager_jni.h>
28 #include <android/log.h>
29 #include <android/sharedmem.h>
30 #include <sys/mman.h>
31
32
33 extern "C"
34 JNIEXPORT jlong
35 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_initModel(JNIEnv * env,jobject,jstring _modelFileName,jboolean _useNnApi,jboolean _enableIntermediateTensorsDump,jstring _nnApiDeviceName,jboolean _mmapModel,jstring _nnApiCacheDir)36 Java_com_android_nn_benchmark_core_NNTestBase_initModel(
37 JNIEnv *env,
38 jobject /* this */,
39 jstring _modelFileName,
40 jboolean _useNnApi,
41 jboolean _enableIntermediateTensorsDump,
42 jstring _nnApiDeviceName,
43 jboolean _mmapModel,
44 jstring _nnApiCacheDir) {
45 const char *modelFileName = env->GetStringUTFChars(_modelFileName, NULL);
46 const char *nnApiDeviceName =
47 _nnApiDeviceName == NULL
48 ? NULL
49 : env->GetStringUTFChars(_nnApiDeviceName, NULL);
50 const char *nnApiCacheDir =
51 _nnApiCacheDir == NULL
52 ? NULL
53 : env->GetStringUTFChars(_nnApiCacheDir, NULL);
54 int nnapiErrno = 0;
55 void *handle = BenchmarkModel::create(
56 modelFileName, _useNnApi, _enableIntermediateTensorsDump, &nnapiErrno,
57 nnApiDeviceName, _mmapModel, nnApiCacheDir);
58 env->ReleaseStringUTFChars(_modelFileName, modelFileName);
59 if (_nnApiDeviceName != NULL) {
60 env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
61 }
62
63 if (_useNnApi && nnapiErrno != 0) {
64 jclass nnapiFailureClass = env->FindClass(
65 "com/android/nn/benchmark/core/NnApiDelegationFailure");
66 jmethodID constructor =
67 env->GetMethodID(nnapiFailureClass, "<init>", "(I)V");
68 jobject exception =
69 env->NewObject(nnapiFailureClass, constructor, nnapiErrno);
70 env->Throw(static_cast<jthrowable>(exception));
71 }
72
73 return (jlong)(uintptr_t)handle;
74 }
75
76 extern "C"
77 JNIEXPORT void
78 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(JNIEnv * env,jobject,jlong _modelHandle)79 Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(
80 JNIEnv *env,
81 jobject /* this */,
82 jlong _modelHandle) {
83 BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
84 delete(model);
85 }
86
87 extern "C"
88 JNIEXPORT jboolean
89 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(JNIEnv * env,jobject,jlong _modelHandle,jintArray _inputShape)90 Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(
91 JNIEnv *env,
92 jobject /* this */,
93 jlong _modelHandle,
94 jintArray _inputShape) {
95 BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
96 jint* shapePtr = env->GetIntArrayElements(_inputShape, nullptr);
97 jsize shapeLen = env->GetArrayLength(_inputShape);
98
99 std::vector<int> shape(shapePtr, shapePtr + shapeLen);
100 return model->resizeInputTensors(std::move(shape));
101 }
102
103 /** RAII container for a list of InferenceInOutSequence to handle JNI data release in destructor. */
104 class InferenceInOutSequenceList {
105 public:
106 InferenceInOutSequenceList(JNIEnv *env,
107 const jobject& inOutDataList,
108 bool expectGoldenOutputs);
109 ~InferenceInOutSequenceList();
110
isValid() const111 bool isValid() const { return mValid; }
112
data() const113 const std::vector<InferenceInOutSequence>& data() const { return mData; }
114
115 private:
116 JNIEnv *mEnv; // not owned.
117
118 std::vector<InferenceInOutSequence> mData;
119 std::vector<jbyteArray> mInputArrays;
120 std::vector<jobjectArray> mOutputArrays;
121 bool mValid;
122 };
123
InferenceInOutSequenceList(JNIEnv * env,const jobject & inOutDataList,bool expectGoldenOutputs)124 InferenceInOutSequenceList::InferenceInOutSequenceList(JNIEnv *env,
125 const jobject& inOutDataList,
126 bool expectGoldenOutputs)
127 : mEnv(env), mValid(false) {
128
129 jclass list_class = env->FindClass("java/util/List");
130 if (list_class == nullptr) { return; }
131 jmethodID list_size = env->GetMethodID(list_class, "size", "()I");
132 if (list_size == nullptr) { return; }
133 jmethodID list_get = env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
134 if (list_get == nullptr) { return; }
135 jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
136 if (list_add == nullptr) { return; }
137
138 jclass inOutSeq_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOutSequence");
139 if (inOutSeq_class == nullptr) { return; }
140 jmethodID inOutSeq_size = env->GetMethodID(inOutSeq_class, "size", "()I");
141 if (inOutSeq_size == nullptr) { return; }
142 jmethodID inOutSeq_get = env->GetMethodID(inOutSeq_class, "get",
143 "(I)Lcom/android/nn/benchmark/core/InferenceInOut;");
144 if (inOutSeq_get == nullptr) { return; }
145
146 jclass inout_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut");
147 if (inout_class == nullptr) { return; }
148 jfieldID inout_input = env->GetFieldID(inout_class, "mInput", "[B");
149 if (inout_input == nullptr) { return; }
150 jfieldID inout_expectedOutputs = env->GetFieldID(inout_class, "mExpectedOutputs", "[[B");
151 if (inout_expectedOutputs == nullptr) { return; }
152 jfieldID inout_inputCreator = env->GetFieldID(inout_class, "mInputCreator",
153 "Lcom/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface;");
154 if (inout_inputCreator == nullptr) { return; }
155
156
157
158 // Fetch input/output arrays
159 size_t data_count = mEnv->CallIntMethod(inOutDataList, list_size);
160 if (env->ExceptionCheck()) { return; }
161 mData.reserve(data_count);
162
163 jclass inputCreator_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface");
164 if (inputCreator_class == nullptr) { return; }
165 jmethodID createInput_method = env->GetMethodID(inputCreator_class, "createInput", "(Ljava/nio/ByteBuffer;)V");
166 if (createInput_method == nullptr) { return; }
167
168 for (int seq_index = 0; seq_index < data_count; ++seq_index) {
169 jobject inOutSeq = mEnv->CallObjectMethod(inOutDataList, list_get, seq_index);
170 if (mEnv->ExceptionCheck()) { return; }
171
172 size_t seqLen = mEnv->CallIntMethod(inOutSeq, inOutSeq_size);
173 if (mEnv->ExceptionCheck()) { return; }
174
175 mData.push_back(InferenceInOutSequence{});
176 auto& seq = mData.back();
177 seq.reserve(seqLen);
178 for (int i = 0; i < seqLen; ++i) {
179 jobject inout = mEnv->CallObjectMethod(inOutSeq, inOutSeq_get, i);
180 if (mEnv->ExceptionCheck()) { return; }
181
182 uint8_t* input_data = nullptr;
183 size_t input_len = 0;
184 std::function<bool(uint8_t*, size_t)> inputCreator;
185 jbyteArray input = static_cast<jbyteArray>(
186 mEnv->GetObjectField(inout, inout_input));
187 mInputArrays.push_back(input);
188 if (input != nullptr) {
189 input_data = reinterpret_cast<uint8_t*>(
190 mEnv->GetByteArrayElements(input, NULL));
191 input_len = mEnv->GetArrayLength(input);
192 } else {
193 inputCreator = [env, inout, inout_inputCreator, createInput_method](
194 uint8_t* buffer, size_t length) {
195 jobject byteBuffer = env->NewDirectByteBuffer(buffer, length);
196 if (byteBuffer == nullptr) { return false; }
197 jobject creator = env->GetObjectField(inout, inout_inputCreator);
198 if (creator == nullptr) { return false; }
199 env->CallVoidMethod(creator, createInput_method, byteBuffer);
200 env->DeleteLocalRef(byteBuffer);
201 if (env->ExceptionCheck()) { return false; }
202 return true;
203 };
204 }
205
206 jobjectArray expectedOutputs = static_cast<jobjectArray>(
207 mEnv->GetObjectField(inout, inout_expectedOutputs));
208 mOutputArrays.push_back(expectedOutputs);
209 seq.push_back({input_data, input_len, {}, inputCreator});
210
211 // Add expected output to sequence added above
212 if (expectedOutputs != nullptr) {
213 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
214 auto& outputs = seq.back().outputs;
215 outputs.reserve(expectedOutputsLength);
216
217 for (jsize j = 0;j < expectedOutputsLength; ++j) {
218 jbyteArray expectedOutput =
219 static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
220 if (env->ExceptionCheck()) {
221 return;
222 }
223 if (expectedOutput == nullptr) {
224 jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
225 mEnv->ThrowNew(iaeClass, "Null expected output array");
226 return;
227 }
228
229 uint8_t *expectedOutput_data = reinterpret_cast<uint8_t*>(
230 mEnv->GetByteArrayElements(expectedOutput, NULL));
231 size_t expectedOutput_len = mEnv->GetArrayLength(expectedOutput);
232 outputs.push_back({ expectedOutput_data, expectedOutput_len});
233 }
234 } else {
235 if (expectGoldenOutputs) {
236 jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
237 mEnv->ThrowNew(iaeClass, "Expected golden output for every input");
238 return;
239 }
240 }
241 }
242 }
243 mValid = true;
244 }
245
~InferenceInOutSequenceList()246 InferenceInOutSequenceList::~InferenceInOutSequenceList() {
247 // Note that we may land here with a pending JNI exception so cannot call
248 // java objects.
249 int arrayIndex = 0;
250 for (int seq_index = 0; seq_index < mData.size(); ++seq_index) {
251 for (int i = 0; i < mData[seq_index].size(); ++i) {
252 jbyteArray input = mInputArrays[arrayIndex];
253 if (input != nullptr) {
254 mEnv->ReleaseByteArrayElements(
255 input, reinterpret_cast<jbyte*>(mData[seq_index][i].input), JNI_ABORT);
256 }
257 jobjectArray expectedOutputs = mOutputArrays[arrayIndex];
258 if (expectedOutputs != nullptr) {
259 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
260 if (expectedOutputsLength != mData[seq_index][i].outputs.size()) {
261 // Should not happen? :)
262 jclass iaeClass = mEnv->FindClass("java/lang/IllegalStateException");
263 mEnv->ThrowNew(iaeClass, "Mismatch of the size of expected outputs jni array "
264 "and internal array of its bufers");
265 return;
266 }
267
268 for (jsize j = 0;j < expectedOutputsLength; ++j) {
269 jbyteArray expectedOutput = static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
270 mEnv->ReleaseByteArrayElements(
271 expectedOutput, reinterpret_cast<jbyte*>(mData[seq_index][i].outputs[j].ptr),
272 JNI_ABORT);
273 }
274 }
275 arrayIndex++;
276 }
277 }
278 }
279
280 extern "C"
281 JNIEXPORT jboolean
282 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(JNIEnv * env,jobject,jlong _modelHandle,jobject inOutDataList,jobject resultList,jint inferencesSeqMaxCount,jfloat timeoutSec,jint flags)283 Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(
284 JNIEnv *env,
285 jobject /* this */,
286 jlong _modelHandle,
287 jobject inOutDataList,
288 jobject resultList,
289 jint inferencesSeqMaxCount,
290 jfloat timeoutSec,
291 jint flags) {
292
293 BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
294
295 jclass list_class = env->FindClass("java/util/List");
296 if (list_class == nullptr) { return false; }
297 jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
298 if (list_add == nullptr) { return false; }
299
300 jclass result_class = env->FindClass("com/android/nn/benchmark/core/InferenceResult");
301 if (result_class == nullptr) { return false; }
302 jmethodID result_ctor = env->GetMethodID(result_class, "<init>", "(F[F[F[[BII)V");
303 if (result_ctor == nullptr) { return false; }
304
305 std::vector<InferenceResult> result;
306
307 const bool expectGoldenOutputs = (flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0;
308 InferenceInOutSequenceList data(env, inOutDataList, expectGoldenOutputs);
309 if (!data.isValid()) {
310 return false;
311 }
312
313 // TODO: Remove success boolean from this method and throw an exception in case of problems
314 bool success = model->benchmark(data.data(), inferencesSeqMaxCount, timeoutSec, flags, &result);
315
316 // Generate results
317 if (success) {
318 for (const InferenceResult &rentry : result) {
319 jobjectArray inferenceOutputs = nullptr;
320 jfloatArray meanSquareErrorArray = nullptr;
321 jfloatArray maxSingleErrorArray = nullptr;
322
323 if ((flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0) {
324 meanSquareErrorArray = env->NewFloatArray(rentry.meanSquareErrors.size());
325 if (env->ExceptionCheck()) { return false; }
326 maxSingleErrorArray = env->NewFloatArray(rentry.maxSingleErrors.size());
327 if (env->ExceptionCheck()) { return false; }
328 {
329 jfloat *bytes = env->GetFloatArrayElements(meanSquareErrorArray, nullptr);
330 memcpy(bytes,
331 &rentry.meanSquareErrors[0],
332 rentry.meanSquareErrors.size() * sizeof(float));
333 env->ReleaseFloatArrayElements(meanSquareErrorArray, bytes, 0);
334 }
335 {
336 jfloat *bytes = env->GetFloatArrayElements(maxSingleErrorArray, nullptr);
337 memcpy(bytes,
338 &rentry.maxSingleErrors[0],
339 rentry.maxSingleErrors.size() * sizeof(float));
340 env->ReleaseFloatArrayElements(maxSingleErrorArray, bytes, 0);
341 }
342 }
343
344 if ((flags & FLAG_DISCARD_INFERENCE_OUTPUT) == 0) {
345 jclass byteArrayClass = env->FindClass("[B");
346
347 inferenceOutputs = env->NewObjectArray(
348 rentry.inferenceOutputs.size(),
349 byteArrayClass, nullptr);
350
351 for (int i = 0;i < rentry.inferenceOutputs.size();++i) {
352 jbyteArray inferenceOutput = nullptr;
353 inferenceOutput = env->NewByteArray(rentry.inferenceOutputs[i].size());
354 if (env->ExceptionCheck()) { return false; }
355 jbyte *bytes = env->GetByteArrayElements(inferenceOutput, nullptr);
356 memcpy(bytes, &rentry.inferenceOutputs[i][0], rentry.inferenceOutputs[i].size());
357 env->ReleaseByteArrayElements(inferenceOutput, bytes, 0);
358 env->SetObjectArrayElement(inferenceOutputs, i, inferenceOutput);
359 }
360 }
361
362 jobject object = env->NewObject(
363 result_class, result_ctor, rentry.computeTimeSec,
364 meanSquareErrorArray, maxSingleErrorArray, inferenceOutputs,
365 rentry.inputOutputSequenceIndex, rentry.inputOutputIndex);
366 if (env->ExceptionCheck() || object == NULL) { return false; }
367
368 env->CallBooleanMethod(resultList, list_add, object);
369 if (env->ExceptionCheck()) { return false; }
370
371 // Releasing local references to objects to avoid local reference table overflow
372 // if tests is set to run for long time.
373 if (meanSquareErrorArray) {
374 env->DeleteLocalRef(meanSquareErrorArray);
375 }
376 if (maxSingleErrorArray) {
377 env->DeleteLocalRef(maxSingleErrorArray);
378 }
379 env->DeleteLocalRef(object);
380 }
381 }
382
383 return success;
384 }
385
386 extern "C"
387 JNIEXPORT void
388 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(JNIEnv * env,jobject,jlong _modelHandle,jstring dumpPath,jobject inOutDataList)389 Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(
390 JNIEnv *env,
391 jobject /* this */,
392 jlong _modelHandle,
393 jstring dumpPath,
394 jobject inOutDataList) {
395
396 BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
397
398 InferenceInOutSequenceList data(env, inOutDataList, /*expectGoldenOutputs=*/false);
399 if (!data.isValid()) {
400 return;
401 }
402
403 const char *dumpPathStr = env->GetStringUTFChars(dumpPath, JNI_FALSE);
404 model->dumpAllLayers(dumpPathStr, data.data());
405 env->ReleaseStringUTFChars(dumpPath, dumpPathStr);
406 }
407
408 extern "C"
409 JNIEXPORT jboolean
410 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_hasAccelerator()411 Java_com_android_nn_benchmark_core_NNTestBase_hasAccelerator() {
412 uint32_t device_count = 0;
413 NnApiImplementation()->ANeuralNetworks_getDeviceCount(&device_count);
414 // We only consider a real device, not 'nnapi-reference'.
415 return device_count > 1;
416 }
417
418 extern "C"
419 JNIEXPORT jboolean
420 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_getAcceleratorNames(JNIEnv * env,jclass,jobject resultList)421 Java_com_android_nn_benchmark_core_NNTestBase_getAcceleratorNames(
422 JNIEnv *env,
423 jclass, /* clazz */
424 jobject resultList
425 ) {
426 uint32_t device_count = 0;
427 auto nnapi_result = NnApiImplementation()->ANeuralNetworks_getDeviceCount(&device_count);
428 if (nnapi_result != 0) {
429 return false;
430 }
431
432 jclass list_class = env->FindClass("java/util/List");
433 if (list_class == nullptr) { return false; }
434 jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
435 if (list_add == nullptr) { return false; }
436
437 for (int i = 0; i < device_count; i++) {
438 ANeuralNetworksDevice* device = nullptr;
439 nnapi_result = NnApiImplementation()->ANeuralNetworks_getDevice(i, &device);
440 if (nnapi_result != 0) {
441 return false;
442 }
443 const char* buffer = nullptr;
444 nnapi_result = NnApiImplementation()->ANeuralNetworksDevice_getName(device, &buffer);
445 if (nnapi_result != 0) {
446 return false;
447 }
448
449 auto device_name = env->NewStringUTF(buffer);
450
451 env->CallBooleanMethod(resultList, list_add, device_name);
452 if (env->ExceptionCheck()) { return false; }
453 }
454 return true;
455 }
456
convertToJfloatArray(JNIEnv * env,const std::vector<float> & from)457 static jfloatArray convertToJfloatArray(JNIEnv* env, const std::vector<float>& from) {
458 jfloatArray to = env->NewFloatArray(from.size());
459 if (env->ExceptionCheck()) {
460 return nullptr;
461 }
462 jfloat* bytes = env->GetFloatArrayElements(to, nullptr);
463 memcpy(bytes, from.data(), from.size() * sizeof(float));
464 env->ReleaseFloatArrayElements(to, bytes, 0);
465 return to;
466 }
467
468 extern "C" JNIEXPORT jobject JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runCompilationBenchmark(JNIEnv * env,jobject,jlong _modelHandle,jint maxNumIterations,jfloat warmupTimeoutSec,jfloat runTimeoutSec)469 Java_com_android_nn_benchmark_core_NNTestBase_runCompilationBenchmark(
470 JNIEnv* env,
471 jobject /* this */,
472 jlong _modelHandle,
473 jint maxNumIterations,
474 jfloat warmupTimeoutSec,
475 jfloat runTimeoutSec) {
476 BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
477
478 jclass result_class = env->FindClass("com/android/nn/benchmark/core/CompilationBenchmarkResult");
479 if (result_class == nullptr) return nullptr;
480 jmethodID result_ctor = env->GetMethodID(result_class, "<init>", "([F[F[FI)V");
481 if (result_ctor == nullptr) return nullptr;
482
483 CompilationBenchmarkResult result;
484 bool success =
485 model->benchmarkCompilation(maxNumIterations, warmupTimeoutSec, runTimeoutSec, &result);
486 if (!success) return nullptr;
487
488 // Convert cpp CompilationBenchmarkResult struct to java.
489 jfloatArray compileWithoutCacheArray =
490 convertToJfloatArray(env, result.compileWithoutCacheTimeSec);
491 if (compileWithoutCacheArray == nullptr) return nullptr;
492
493 // saveToCache and prepareFromCache results may not exist.
494 jfloatArray saveToCacheArray = nullptr;
495 if (result.saveToCacheTimeSec) {
496 saveToCacheArray = convertToJfloatArray(env, result.saveToCacheTimeSec.value());
497 if (saveToCacheArray == nullptr) return nullptr;
498 }
499 jfloatArray prepareFromCacheArray = nullptr;
500 if (result.prepareFromCacheTimeSec) {
501 prepareFromCacheArray = convertToJfloatArray(env, result.prepareFromCacheTimeSec.value());
502 if (prepareFromCacheArray == nullptr) return nullptr;
503 }
504
505 jobject object = env->NewObject(result_class, result_ctor, compileWithoutCacheArray,
506 saveToCacheArray, prepareFromCacheArray, result.cacheSizeBytes);
507 if (env->ExceptionCheck()) return nullptr;
508 return object;
509 }
510