1 /**
2  * Copyright 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "run_tflite.h"
18 
19 #include "tensorflow/lite/nnapi/nnapi_implementation.h"
20 
21 #include <jni.h>
22 #include <string>
23 #include <iomanip>
24 #include <sstream>
25 #include <fcntl.h>
26 
27 #include <android/asset_manager_jni.h>
28 #include <android/log.h>
29 #include <android/sharedmem.h>
30 #include <sys/mman.h>
31 
32 
33 extern "C"
34 JNIEXPORT jlong
35 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_initModel(JNIEnv * env,jobject,jstring _modelFileName,jboolean _useNnApi,jboolean _enableIntermediateTensorsDump,jstring _nnApiDeviceName,jboolean _mmapModel,jstring _nnApiCacheDir)36 Java_com_android_nn_benchmark_core_NNTestBase_initModel(
37         JNIEnv *env,
38         jobject /* this */,
39         jstring _modelFileName,
40         jboolean _useNnApi,
41         jboolean _enableIntermediateTensorsDump,
42         jstring _nnApiDeviceName,
43         jboolean _mmapModel,
44         jstring _nnApiCacheDir) {
45     const char *modelFileName = env->GetStringUTFChars(_modelFileName, NULL);
46     const char *nnApiDeviceName =
47         _nnApiDeviceName == NULL
48             ? NULL
49             : env->GetStringUTFChars(_nnApiDeviceName, NULL);
50     const char *nnApiCacheDir =
51         _nnApiCacheDir == NULL
52             ? NULL
53             : env->GetStringUTFChars(_nnApiCacheDir, NULL);
54     int nnapiErrno = 0;
55     void *handle = BenchmarkModel::create(
56         modelFileName, _useNnApi, _enableIntermediateTensorsDump, &nnapiErrno,
57         nnApiDeviceName, _mmapModel, nnApiCacheDir);
58     env->ReleaseStringUTFChars(_modelFileName, modelFileName);
59     if (_nnApiDeviceName != NULL) {
60         env->ReleaseStringUTFChars(_nnApiDeviceName, nnApiDeviceName);
61     }
62 
63     if (_useNnApi && nnapiErrno != 0) {
64       jclass nnapiFailureClass = env->FindClass(
65           "com/android/nn/benchmark/core/NnApiDelegationFailure");
66       jmethodID constructor =
67           env->GetMethodID(nnapiFailureClass, "<init>", "(I)V");
68       jobject exception =
69           env->NewObject(nnapiFailureClass, constructor, nnapiErrno);
70       env->Throw(static_cast<jthrowable>(exception));
71     }
72 
73     return (jlong)(uintptr_t)handle;
74 }
75 
76 extern "C"
77 JNIEXPORT void
78 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(JNIEnv * env,jobject,jlong _modelHandle)79 Java_com_android_nn_benchmark_core_NNTestBase_destroyModel(
80         JNIEnv *env,
81         jobject /* this */,
82         jlong _modelHandle) {
83     BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
84     delete(model);
85 }
86 
87 extern "C"
88 JNIEXPORT jboolean
89 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(JNIEnv * env,jobject,jlong _modelHandle,jintArray _inputShape)90 Java_com_android_nn_benchmark_core_NNTestBase_resizeInputTensors(
91         JNIEnv *env,
92         jobject /* this */,
93         jlong _modelHandle,
94         jintArray _inputShape) {
95     BenchmarkModel* model = (BenchmarkModel *) _modelHandle;
96     jint* shapePtr = env->GetIntArrayElements(_inputShape, nullptr);
97     jsize shapeLen = env->GetArrayLength(_inputShape);
98 
99     std::vector<int> shape(shapePtr, shapePtr + shapeLen);
100     return model->resizeInputTensors(std::move(shape));
101 }
102 
103 /** RAII container for a list of InferenceInOutSequence to handle JNI data release in destructor. */
104 class InferenceInOutSequenceList {
105 public:
106     InferenceInOutSequenceList(JNIEnv *env,
107                                const jobject& inOutDataList,
108                                bool expectGoldenOutputs);
109     ~InferenceInOutSequenceList();
110 
isValid() const111     bool isValid() const { return mValid; }
112 
data() const113     const std::vector<InferenceInOutSequence>& data() const { return mData; }
114 
115 private:
116     JNIEnv *mEnv;  // not owned.
117 
118     std::vector<InferenceInOutSequence> mData;
119     std::vector<jbyteArray> mInputArrays;
120     std::vector<jobjectArray> mOutputArrays;
121     bool mValid;
122 };
123 
InferenceInOutSequenceList(JNIEnv * env,const jobject & inOutDataList,bool expectGoldenOutputs)124 InferenceInOutSequenceList::InferenceInOutSequenceList(JNIEnv *env,
125                                                        const jobject& inOutDataList,
126                                                        bool expectGoldenOutputs)
127     : mEnv(env), mValid(false) {
128 
129     jclass list_class = env->FindClass("java/util/List");
130     if (list_class == nullptr) { return; }
131     jmethodID list_size = env->GetMethodID(list_class, "size", "()I");
132     if (list_size == nullptr) { return; }
133     jmethodID list_get = env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
134     if (list_get == nullptr) { return; }
135     jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
136     if (list_add == nullptr) { return; }
137 
138     jclass inOutSeq_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOutSequence");
139     if (inOutSeq_class == nullptr) { return; }
140     jmethodID inOutSeq_size = env->GetMethodID(inOutSeq_class, "size", "()I");
141     if (inOutSeq_size == nullptr) { return; }
142     jmethodID inOutSeq_get = env->GetMethodID(inOutSeq_class, "get",
143                                               "(I)Lcom/android/nn/benchmark/core/InferenceInOut;");
144     if (inOutSeq_get == nullptr) { return; }
145 
146     jclass inout_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut");
147     if (inout_class == nullptr) { return; }
148     jfieldID inout_input = env->GetFieldID(inout_class, "mInput", "[B");
149     if (inout_input == nullptr) { return; }
150     jfieldID inout_expectedOutputs = env->GetFieldID(inout_class, "mExpectedOutputs", "[[B");
151     if (inout_expectedOutputs == nullptr) { return; }
152     jfieldID inout_inputCreator = env->GetFieldID(inout_class, "mInputCreator",
153             "Lcom/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface;");
154     if (inout_inputCreator == nullptr) { return; }
155 
156 
157 
158     // Fetch input/output arrays
159     size_t data_count = mEnv->CallIntMethod(inOutDataList, list_size);
160     if (env->ExceptionCheck()) { return; }
161     mData.reserve(data_count);
162 
163     jclass inputCreator_class = env->FindClass("com/android/nn/benchmark/core/InferenceInOut$InputCreatorInterface");
164     if (inputCreator_class == nullptr) { return; }
165     jmethodID createInput_method = env->GetMethodID(inputCreator_class, "createInput", "(Ljava/nio/ByteBuffer;)V");
166     if (createInput_method == nullptr) { return; }
167 
168     for (int seq_index = 0; seq_index < data_count; ++seq_index) {
169         jobject inOutSeq = mEnv->CallObjectMethod(inOutDataList, list_get, seq_index);
170         if (mEnv->ExceptionCheck()) { return; }
171 
172         size_t seqLen = mEnv->CallIntMethod(inOutSeq, inOutSeq_size);
173         if (mEnv->ExceptionCheck()) { return; }
174 
175         mData.push_back(InferenceInOutSequence{});
176         auto& seq = mData.back();
177         seq.reserve(seqLen);
178         for (int i = 0; i < seqLen; ++i) {
179             jobject inout = mEnv->CallObjectMethod(inOutSeq, inOutSeq_get, i);
180             if (mEnv->ExceptionCheck()) { return; }
181 
182             uint8_t* input_data = nullptr;
183             size_t input_len = 0;
184             std::function<bool(uint8_t*, size_t)> inputCreator;
185             jbyteArray input = static_cast<jbyteArray>(
186                     mEnv->GetObjectField(inout, inout_input));
187             mInputArrays.push_back(input);
188             if (input != nullptr) {
189                 input_data = reinterpret_cast<uint8_t*>(
190                         mEnv->GetByteArrayElements(input, NULL));
191                 input_len = mEnv->GetArrayLength(input);
192             } else {
193                 inputCreator = [env, inout, inout_inputCreator, createInput_method](
194                         uint8_t* buffer, size_t length) {
195                     jobject byteBuffer = env->NewDirectByteBuffer(buffer, length);
196                     if (byteBuffer == nullptr) { return false; }
197                     jobject creator = env->GetObjectField(inout, inout_inputCreator);
198                     if (creator == nullptr) { return false; }
199                     env->CallVoidMethod(creator, createInput_method, byteBuffer);
200                     env->DeleteLocalRef(byteBuffer);
201                     if (env->ExceptionCheck()) { return false; }
202                     return true;
203                 };
204             }
205 
206             jobjectArray expectedOutputs = static_cast<jobjectArray>(
207                     mEnv->GetObjectField(inout, inout_expectedOutputs));
208             mOutputArrays.push_back(expectedOutputs);
209             seq.push_back({input_data, input_len, {}, inputCreator});
210 
211             // Add expected output to sequence added above
212             if (expectedOutputs != nullptr) {
213                 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
214                 auto& outputs = seq.back().outputs;
215                 outputs.reserve(expectedOutputsLength);
216 
217                 for (jsize j = 0;j < expectedOutputsLength; ++j) {
218                     jbyteArray expectedOutput =
219                             static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
220                     if (env->ExceptionCheck()) {
221                         return;
222                     }
223                     if (expectedOutput == nullptr) {
224                         jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
225                         mEnv->ThrowNew(iaeClass, "Null expected output array");
226                         return;
227                     }
228 
229                     uint8_t *expectedOutput_data = reinterpret_cast<uint8_t*>(
230                                         mEnv->GetByteArrayElements(expectedOutput, NULL));
231                     size_t expectedOutput_len = mEnv->GetArrayLength(expectedOutput);
232                     outputs.push_back({ expectedOutput_data, expectedOutput_len});
233                 }
234             } else {
235                 if (expectGoldenOutputs) {
236                     jclass iaeClass = mEnv->FindClass("java/lang/IllegalArgumentException");
237                     mEnv->ThrowNew(iaeClass, "Expected golden output for every input");
238                     return;
239                 }
240             }
241         }
242     }
243     mValid = true;
244 }
245 
~InferenceInOutSequenceList()246 InferenceInOutSequenceList::~InferenceInOutSequenceList() {
247     // Note that we may land here with a pending JNI exception so cannot call
248     // java objects.
249     int arrayIndex = 0;
250     for (int seq_index = 0; seq_index < mData.size(); ++seq_index) {
251         for (int i = 0; i < mData[seq_index].size(); ++i) {
252             jbyteArray input = mInputArrays[arrayIndex];
253             if (input != nullptr) {
254                 mEnv->ReleaseByteArrayElements(
255                         input, reinterpret_cast<jbyte*>(mData[seq_index][i].input), JNI_ABORT);
256             }
257             jobjectArray expectedOutputs = mOutputArrays[arrayIndex];
258             if (expectedOutputs != nullptr) {
259                 jsize expectedOutputsLength = mEnv->GetArrayLength(expectedOutputs);
260                 if (expectedOutputsLength != mData[seq_index][i].outputs.size()) {
261                     // Should not happen? :)
262                     jclass iaeClass = mEnv->FindClass("java/lang/IllegalStateException");
263                     mEnv->ThrowNew(iaeClass, "Mismatch of the size of expected outputs jni array "
264                                    "and internal array of its bufers");
265                     return;
266                 }
267 
268                 for (jsize j = 0;j < expectedOutputsLength; ++j) {
269                     jbyteArray expectedOutput = static_cast<jbyteArray>(mEnv->GetObjectArrayElement(expectedOutputs, j));
270                     mEnv->ReleaseByteArrayElements(
271                         expectedOutput, reinterpret_cast<jbyte*>(mData[seq_index][i].outputs[j].ptr),
272                         JNI_ABORT);
273                 }
274             }
275             arrayIndex++;
276         }
277     }
278 }
279 
280 extern "C"
281 JNIEXPORT jboolean
282 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(JNIEnv * env,jobject,jlong _modelHandle,jobject inOutDataList,jobject resultList,jint inferencesSeqMaxCount,jfloat timeoutSec,jint flags)283 Java_com_android_nn_benchmark_core_NNTestBase_runBenchmark(
284         JNIEnv *env,
285         jobject /* this */,
286         jlong _modelHandle,
287         jobject inOutDataList,
288         jobject resultList,
289         jint inferencesSeqMaxCount,
290         jfloat timeoutSec,
291         jint flags) {
292 
293     BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
294 
295     jclass list_class = env->FindClass("java/util/List");
296     if (list_class == nullptr) { return false; }
297     jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
298     if (list_add == nullptr) { return false; }
299 
300     jclass result_class = env->FindClass("com/android/nn/benchmark/core/InferenceResult");
301     if (result_class == nullptr) { return false; }
302     jmethodID result_ctor = env->GetMethodID(result_class, "<init>", "(F[F[F[[BII)V");
303     if (result_ctor == nullptr) { return false; }
304 
305     std::vector<InferenceResult> result;
306 
307     const bool expectGoldenOutputs = (flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0;
308     InferenceInOutSequenceList data(env, inOutDataList, expectGoldenOutputs);
309     if (!data.isValid()) {
310         return false;
311     }
312 
313     // TODO: Remove success boolean from this method and throw an exception in case of problems
314     bool success = model->benchmark(data.data(), inferencesSeqMaxCount, timeoutSec, flags, &result);
315 
316     // Generate results
317     if (success) {
318         for (const InferenceResult &rentry : result) {
319             jobjectArray inferenceOutputs = nullptr;
320             jfloatArray meanSquareErrorArray = nullptr;
321             jfloatArray maxSingleErrorArray = nullptr;
322 
323             if ((flags & FLAG_IGNORE_GOLDEN_OUTPUT) == 0) {
324                 meanSquareErrorArray = env->NewFloatArray(rentry.meanSquareErrors.size());
325                 if (env->ExceptionCheck()) { return false; }
326                 maxSingleErrorArray = env->NewFloatArray(rentry.maxSingleErrors.size());
327                 if (env->ExceptionCheck()) { return false; }
328                 {
329                     jfloat *bytes = env->GetFloatArrayElements(meanSquareErrorArray, nullptr);
330                     memcpy(bytes,
331                            &rentry.meanSquareErrors[0],
332                            rentry.meanSquareErrors.size() * sizeof(float));
333                     env->ReleaseFloatArrayElements(meanSquareErrorArray, bytes, 0);
334                 }
335                 {
336                     jfloat *bytes = env->GetFloatArrayElements(maxSingleErrorArray, nullptr);
337                     memcpy(bytes,
338                            &rentry.maxSingleErrors[0],
339                            rentry.maxSingleErrors.size() * sizeof(float));
340                     env->ReleaseFloatArrayElements(maxSingleErrorArray, bytes, 0);
341                 }
342             }
343 
344             if ((flags & FLAG_DISCARD_INFERENCE_OUTPUT) == 0) {
345                 jclass byteArrayClass = env->FindClass("[B");
346 
347                 inferenceOutputs = env->NewObjectArray(
348                     rentry.inferenceOutputs.size(),
349                     byteArrayClass, nullptr);
350 
351                 for (int i = 0;i < rentry.inferenceOutputs.size();++i) {
352                     jbyteArray inferenceOutput = nullptr;
353                     inferenceOutput = env->NewByteArray(rentry.inferenceOutputs[i].size());
354                     if (env->ExceptionCheck()) { return false; }
355                     jbyte *bytes = env->GetByteArrayElements(inferenceOutput, nullptr);
356                     memcpy(bytes, &rentry.inferenceOutputs[i][0], rentry.inferenceOutputs[i].size());
357                     env->ReleaseByteArrayElements(inferenceOutput, bytes, 0);
358                     env->SetObjectArrayElement(inferenceOutputs, i, inferenceOutput);
359                 }
360             }
361 
362             jobject object = env->NewObject(
363                 result_class, result_ctor, rentry.computeTimeSec,
364                 meanSquareErrorArray, maxSingleErrorArray, inferenceOutputs,
365                 rentry.inputOutputSequenceIndex, rentry.inputOutputIndex);
366             if (env->ExceptionCheck() || object == NULL) { return false; }
367 
368             env->CallBooleanMethod(resultList, list_add, object);
369             if (env->ExceptionCheck()) { return false; }
370 
371             // Releasing local references to objects to avoid local reference table overflow
372             // if tests is set to run for long time.
373             if (meanSquareErrorArray) {
374                 env->DeleteLocalRef(meanSquareErrorArray);
375             }
376             if (maxSingleErrorArray) {
377                 env->DeleteLocalRef(maxSingleErrorArray);
378             }
379             env->DeleteLocalRef(object);
380         }
381     }
382 
383     return success;
384 }
385 
386 extern "C"
387 JNIEXPORT void
388 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(JNIEnv * env,jobject,jlong _modelHandle,jstring dumpPath,jobject inOutDataList)389 Java_com_android_nn_benchmark_core_NNTestBase_dumpAllLayers(
390         JNIEnv *env,
391         jobject /* this */,
392         jlong _modelHandle,
393         jstring dumpPath,
394         jobject inOutDataList) {
395 
396     BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
397 
398     InferenceInOutSequenceList data(env, inOutDataList, /*expectGoldenOutputs=*/false);
399     if (!data.isValid()) {
400         return;
401     }
402 
403     const char *dumpPathStr = env->GetStringUTFChars(dumpPath, JNI_FALSE);
404     model->dumpAllLayers(dumpPathStr, data.data());
405     env->ReleaseStringUTFChars(dumpPath, dumpPathStr);
406 }
407 
408 extern "C"
409 JNIEXPORT jboolean
410 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_hasAccelerator()411 Java_com_android_nn_benchmark_core_NNTestBase_hasAccelerator() {
412   uint32_t device_count = 0;
413   NnApiImplementation()->ANeuralNetworks_getDeviceCount(&device_count);
414   // We only consider a real device, not 'nnapi-reference'.
415   return device_count > 1;
416 }
417 
418 extern "C"
419 JNIEXPORT jboolean
420 JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_getAcceleratorNames(JNIEnv * env,jclass,jobject resultList)421 Java_com_android_nn_benchmark_core_NNTestBase_getAcceleratorNames(
422     JNIEnv *env,
423     jclass, /* clazz */
424     jobject resultList
425     ) {
426   uint32_t device_count = 0;
427   auto nnapi_result = NnApiImplementation()->ANeuralNetworks_getDeviceCount(&device_count);
428   if (nnapi_result != 0) {
429     return false;
430   }
431 
432   jclass list_class = env->FindClass("java/util/List");
433   if (list_class == nullptr) { return false; }
434   jmethodID list_add = env->GetMethodID(list_class, "add", "(Ljava/lang/Object;)Z");
435   if (list_add == nullptr) { return false; }
436 
437   for (int i = 0; i < device_count; i++) {
438       ANeuralNetworksDevice* device = nullptr;
439       nnapi_result = NnApiImplementation()->ANeuralNetworks_getDevice(i, &device);
440       if (nnapi_result != 0) {
441           return false;
442        }
443       const char* buffer = nullptr;
444       nnapi_result = NnApiImplementation()->ANeuralNetworksDevice_getName(device, &buffer);
445       if (nnapi_result != 0) {
446         return false;
447       }
448 
449       auto device_name = env->NewStringUTF(buffer);
450 
451       env->CallBooleanMethod(resultList, list_add, device_name);
452       if (env->ExceptionCheck()) { return false; }
453   }
454   return true;
455 }
456 
convertToJfloatArray(JNIEnv * env,const std::vector<float> & from)457 static jfloatArray convertToJfloatArray(JNIEnv* env, const std::vector<float>& from) {
458   jfloatArray to = env->NewFloatArray(from.size());
459   if (env->ExceptionCheck()) {
460     return nullptr;
461   }
462   jfloat* bytes = env->GetFloatArrayElements(to, nullptr);
463   memcpy(bytes, from.data(), from.size() * sizeof(float));
464   env->ReleaseFloatArrayElements(to, bytes, 0);
465   return to;
466 }
467 
468 extern "C" JNIEXPORT jobject JNICALL
Java_com_android_nn_benchmark_core_NNTestBase_runCompilationBenchmark(JNIEnv * env,jobject,jlong _modelHandle,jint maxNumIterations,jfloat warmupTimeoutSec,jfloat runTimeoutSec)469 Java_com_android_nn_benchmark_core_NNTestBase_runCompilationBenchmark(
470     JNIEnv* env,
471     jobject /* this */,
472     jlong _modelHandle,
473     jint maxNumIterations,
474     jfloat warmupTimeoutSec,
475     jfloat runTimeoutSec) {
476   BenchmarkModel* model = reinterpret_cast<BenchmarkModel*>(_modelHandle);
477 
478   jclass result_class = env->FindClass("com/android/nn/benchmark/core/CompilationBenchmarkResult");
479   if (result_class == nullptr) return nullptr;
480   jmethodID result_ctor = env->GetMethodID(result_class, "<init>", "([F[F[FI)V");
481   if (result_ctor == nullptr) return nullptr;
482 
483   CompilationBenchmarkResult result;
484   bool success =
485           model->benchmarkCompilation(maxNumIterations, warmupTimeoutSec, runTimeoutSec, &result);
486   if (!success) return nullptr;
487 
488   // Convert cpp CompilationBenchmarkResult struct to java.
489   jfloatArray compileWithoutCacheArray =
490           convertToJfloatArray(env, result.compileWithoutCacheTimeSec);
491   if (compileWithoutCacheArray == nullptr) return nullptr;
492 
493   // saveToCache and prepareFromCache results may not exist.
494   jfloatArray saveToCacheArray = nullptr;
495   if (result.saveToCacheTimeSec) {
496     saveToCacheArray = convertToJfloatArray(env, result.saveToCacheTimeSec.value());
497     if (saveToCacheArray == nullptr) return nullptr;
498   }
499   jfloatArray prepareFromCacheArray = nullptr;
500   if (result.prepareFromCacheTimeSec) {
501     prepareFromCacheArray = convertToJfloatArray(env, result.prepareFromCacheTimeSec.value());
502     if (prepareFromCacheArray == nullptr) return nullptr;
503   }
504 
505   jobject object = env->NewObject(result_class, result_ctor, compileWithoutCacheArray,
506                                   saveToCacheArray, prepareFromCacheArray, result.cacheSizeBytes);
507   if (env->ExceptionCheck()) return nullptr;
508   return object;
509 }
510