1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import android.annotation.SuppressLint;
20 import android.content.Context;
21 import android.content.res.AssetManager;
22 import android.os.Build;
23 import android.util.Log;
24 import android.util.Pair;
25 import android.widget.TextView;
26 
27 import java.io.File;
28 import java.io.FileOutputStream;
29 import java.io.IOException;
30 import java.io.InputStream;
31 import java.util.ArrayList;
32 import java.util.Collections;
33 import java.util.List;
34 import java.util.Optional;
35 import java.util.Random;
36 import java.util.stream.Collectors;
37 
38 public class NNTestBase implements AutoCloseable {
39     protected static final String TAG = "NN_TESTBASE";
40 
41     // Used to load the 'native-lib' library on application startup.
42     static {
43         System.loadLibrary("nnbenchmark_jni");
44     }
45 
46     // Does the device has any NNAPI accelerator?
47     // We only consider a real device, not 'nnapi-reference'.
hasAccelerator()48     public static native boolean hasAccelerator();
49 
50     /**
51      * Fills resultList with the name of the available NNAPI accelerators
52      *
53      * @return False if any error occurred, true otherwise
54      */
getAcceleratorNames(List<String> resultList)55     private static native boolean getAcceleratorNames(List<String> resultList);
56 
initModel( String modelFileName, boolean useNNApi, boolean enableIntermediateTensorsDump, String nnApiDeviceName, boolean mmapModel, String nnApiCacheDir)57     private synchronized native long initModel(
58             String modelFileName,
59             boolean useNNApi,
60             boolean enableIntermediateTensorsDump,
61             String nnApiDeviceName,
62             boolean mmapModel,
63             String nnApiCacheDir) throws NnApiDelegationFailure;
64 
destroyModel(long modelHandle)65     private synchronized native void destroyModel(long modelHandle);
66 
resizeInputTensors(long modelHandle, int[] inputShape)67     private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);
68 
runBenchmark(long modelHandle, List<InferenceInOutSequence> inOutList, List<InferenceResult> resultList, int inferencesSeqMaxCount, float timeoutSec, int flags)69     private synchronized native boolean runBenchmark(long modelHandle,
70             List<InferenceInOutSequence> inOutList,
71             List<InferenceResult> resultList,
72             int inferencesSeqMaxCount,
73             float timeoutSec,
74             int flags);
75 
runCompilationBenchmark( long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec)76     private synchronized native CompilationBenchmarkResult runCompilationBenchmark(
77             long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec);
78 
dumpAllLayers( long modelHandle, String dumpPath, List<InferenceInOutSequence> inOutList)79     private synchronized native void dumpAllLayers(
80             long modelHandle,
81             String dumpPath,
82             List<InferenceInOutSequence> inOutList);
83 
availableAcceleratorNames()84     public static List<String> availableAcceleratorNames() {
85         List<String> availableAccelerators = new ArrayList<>();
86         if (NNTestBase.getAcceleratorNames(availableAccelerators)) {
87             return availableAccelerators.stream().filter(
88                     acceleratorName -> !acceleratorName.equalsIgnoreCase(
89                             "nnapi-reference")).collect(Collectors.toList());
90         } else {
91             Log.e(TAG, "Unable to retrieve accelerator names!!");
92             return Collections.EMPTY_LIST;
93         }
94     }
95 
96     /** Discard inference output in inference results. */
97     public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
98     /**
99      * Do not expect golden outputs with inference inputs.
100      *
101      * Useful in cases where there's no straightforward golden output values
102      * for the benchmark. This will also skip calculating basic (golden
103      * output based) error metrics.
104      */
105     public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;
106 
107 
108     protected Context mContext;
109     protected TextView mText;
110     private final String mModelName;
111     private final String mModelFile;
112     private long mModelHandle;
113     private final int[] mInputShape;
114     private final InferenceInOutSequence.FromAssets[] mInputOutputAssets;
115     private final InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
116     private final EvaluatorConfig mEvaluatorConfig;
117     private EvaluatorInterface mEvaluator;
118     private boolean mHasGoldenOutputs;
119     private boolean mUseNNApi = false;
120     private boolean mEnableIntermediateTensorsDump = false;
121     private final int mMinSdkVersion;
122     private Optional<String> mNNApiDeviceName = Optional.empty();
123     private boolean mMmapModel = false;
124     // Path where the current model has been stored for execution
125     private String mTemporaryModelFilePath;
126 
NNTestBase(String modelName, String modelFile, int[] inputShape, InferenceInOutSequence.FromAssets[] inputOutputAssets, InferenceInOutSequence.FromDataset[] inputOutputDatasets, EvaluatorConfig evaluator, int minSdkVersion)127     public NNTestBase(String modelName, String modelFile, int[] inputShape,
128             InferenceInOutSequence.FromAssets[] inputOutputAssets,
129             InferenceInOutSequence.FromDataset[] inputOutputDatasets,
130             EvaluatorConfig evaluator, int minSdkVersion) {
131         if (inputOutputAssets == null && inputOutputDatasets == null) {
132             throw new IllegalArgumentException(
133                     "Neither inputOutputAssets or inputOutputDatasets given - no inputs");
134         }
135         if (inputOutputAssets != null && inputOutputDatasets != null) {
136             throw new IllegalArgumentException(
137                     "Both inputOutputAssets or inputOutputDatasets given. Only one" +
138                             "supported at once.");
139         }
140         mModelName = modelName;
141         mModelFile = modelFile;
142         mInputShape = inputShape;
143         mInputOutputAssets = inputOutputAssets;
144         mInputOutputDatasets = inputOutputDatasets;
145         mModelHandle = 0;
146         mEvaluatorConfig = evaluator;
147         mMinSdkVersion = minSdkVersion;
148     }
149 
useNNApi()150     public void useNNApi() {
151         useNNApi(true);
152     }
153 
useNNApi(boolean value)154     public void useNNApi(boolean value) {
155         mUseNNApi = value;
156     }
157 
enableIntermediateTensorsDump()158     public void enableIntermediateTensorsDump() {
159         enableIntermediateTensorsDump(true);
160     }
161 
enableIntermediateTensorsDump(boolean value)162     public void enableIntermediateTensorsDump(boolean value) {
163         mEnableIntermediateTensorsDump = value;
164     }
165 
setNNApiDeviceName(String value)166     public void setNNApiDeviceName(String value) {
167         if (!mUseNNApi) {
168             Log.e(TAG, "Setting device name has no effect when not using NNAPI");
169         }
170         mNNApiDeviceName = Optional.ofNullable(value);
171     }
172 
setMmapModel(boolean value)173     public void setMmapModel(boolean value) {
174         mMmapModel = value;
175     }
176 
setupModel(Context ipcxt)177     public final boolean setupModel(Context ipcxt) throws IOException, NnApiDelegationFailure {
178         mContext = ipcxt;
179         if (mTemporaryModelFilePath != null) {
180             deleteOrWarn(mTemporaryModelFilePath);
181         }
182         mTemporaryModelFilePath = copyAssetToFile();
183         String nnApiCacheDir = mContext.getCodeCacheDir().toString();
184         mModelHandle = initModel(
185                 mTemporaryModelFilePath, mUseNNApi, mEnableIntermediateTensorsDump,
186                 mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir);
187         if (mModelHandle == 0) {
188             Log.e(TAG, "Failed to init the model");
189             return false;
190         }
191         resizeInputTensors(mModelHandle, mInputShape);
192 
193         if (mEvaluatorConfig != null) {
194             mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
195         }
196         return true;
197     }
198 
getTestInfo()199     public String getTestInfo() {
200         return mModelName;
201     }
202 
getEvaluator()203     public EvaluatorInterface getEvaluator() {
204         return mEvaluator;
205     }
206 
checkSdkVersion()207     public void checkSdkVersion() throws UnsupportedSdkException {
208         if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
209             throw new UnsupportedSdkException("SDK version not supported. Mininum required: " +
210                     mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
211         }
212     }
213 
deleteOrWarn(String path)214     private void deleteOrWarn(String path) {
215         if (!new File(path).delete()) {
216             Log.w(TAG, String.format(
217                     "Unable to delete file '%s'. This might cause device to run out of space.",
218                     path));
219         }
220     }
221 
222 
getInputOutputAssets()223     private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
224         // TODO: Caching, don't read inputs for every inference
225         List<InferenceInOutSequence> inOutList =
226                 getInputOutputAssets(mContext, mInputOutputAssets, mInputOutputDatasets);
227 
228         Boolean lastGolden = null;
229         for (InferenceInOutSequence sequence : inOutList) {
230             mHasGoldenOutputs = sequence.hasGoldenOutput();
231             if (lastGolden == null) {
232                 lastGolden = mHasGoldenOutputs;
233             } else {
234                 if (lastGolden != mHasGoldenOutputs) {
235                     throw new IllegalArgumentException(
236                             "Some inputs for " + mModelName + " have outputs while some don't.");
237                 }
238             }
239         }
240         return inOutList;
241     }
242 
getInputOutputAssets(Context context, InferenceInOutSequence.FromAssets[] inputOutputAssets, InferenceInOutSequence.FromDataset[] inputOutputDatasets)243     public static List<InferenceInOutSequence> getInputOutputAssets(Context context,
244             InferenceInOutSequence.FromAssets[] inputOutputAssets,
245             InferenceInOutSequence.FromDataset[] inputOutputDatasets) throws IOException {
246         // TODO: Caching, don't read inputs for every inference
247         List<InferenceInOutSequence> inOutList = new ArrayList<>();
248         if (inputOutputAssets != null) {
249             for (InferenceInOutSequence.FromAssets ioAsset : inputOutputAssets) {
250                 inOutList.add(ioAsset.readAssets(context.getAssets()));
251             }
252         }
253         if (inputOutputDatasets != null) {
254             for (InferenceInOutSequence.FromDataset dataset : inputOutputDatasets) {
255                 inOutList.addAll(dataset.readDataset(context.getAssets(), context.getCacheDir()));
256             }
257         }
258 
259         return inOutList;
260     }
261 
getDefaultFlags()262     public int getDefaultFlags() {
263         int flags = 0;
264         if (!mHasGoldenOutputs) {
265             flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT;
266         }
267         if (mEvaluator == null) {
268             flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
269         }
270         return flags;
271     }
272 
dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)273     public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
274             throws IOException {
275         if (!dumpDir.exists() || !dumpDir.isDirectory()) {
276             throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
277         }
278         if (!mEnableIntermediateTensorsDump) {
279             throw new IllegalStateException("mEnableIntermediateTensorsDump is " +
280                     "set to false, impossible to proceed");
281         }
282 
283         List<InferenceInOutSequence> ios = getInputOutputAssets();
284         dumpAllLayers(mModelHandle, dumpDir.toString(),
285                 ios.subList(inputAssetIndex, inputAssetSize));
286     }
287 
runInferenceOnce()288     public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
289             throws IOException, BenchmarkException {
290         List<InferenceInOutSequence> ios = getInputOutputAssets();
291         int flags = getDefaultFlags();
292         Pair<List<InferenceInOutSequence>, List<InferenceResult>> output =
293                 runBenchmark(ios, 1, Float.MAX_VALUE, flags);
294         return output;
295     }
296 
runBenchmark(float timeoutSec)297     public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
298             throws IOException, BenchmarkException {
299         // Run as many as possible before timeout.
300         int flags = getDefaultFlags();
301         return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
302     }
303 
304     /** Run through whole input set (once or mutliple times). */
runBenchmarkCompleteInputSet( int setRepeat, float timeoutSec)305     public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
306             int setRepeat,
307             float timeoutSec)
308             throws IOException, BenchmarkException {
309         int flags = getDefaultFlags();
310         List<InferenceInOutSequence> ios = getInputOutputAssets();
311         int totalSequenceInferencesCount = ios.size() * setRepeat;
312         int extpectedResults = 0;
313         for (InferenceInOutSequence iosSeq : ios) {
314             extpectedResults += iosSeq.size();
315         }
316         extpectedResults *= setRepeat;
317 
318         Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
319                 runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
320                         flags);
321         if (result.second.size() != extpectedResults) {
322             // We reached a timeout or failed to evaluate whole set for other reason, abort.
323             final String errorMsg = "Failed to evaluate complete input set, expected: "
324                     + extpectedResults +
325                     ", received: " + result.second.size();
326             Log.w(TAG, errorMsg);
327             throw new IllegalStateException(errorMsg);
328         }
329         return result;
330     }
331 
runBenchmark( List<InferenceInOutSequence> inOutList, int inferencesSeqMaxCount, float timeoutSec, int flags)332     public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
333             List<InferenceInOutSequence> inOutList,
334             int inferencesSeqMaxCount,
335             float timeoutSec,
336             int flags)
337             throws IOException, BenchmarkException {
338         if (mModelHandle == 0) {
339             throw new UnsupportedModelException("Unsupported model");
340         }
341         List<InferenceResult> resultList = new ArrayList<>();
342         if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
343                 timeoutSec, flags)) {
344             throw new BenchmarkException("Failed to run benchmark");
345         }
346         return new Pair<List<InferenceInOutSequence>, List<InferenceResult>>(
347                 inOutList, resultList);
348     }
349 
runCompilationBenchmark(float warmupTimeoutSec, float runTimeoutSec, int maxIterations)350     public CompilationBenchmarkResult runCompilationBenchmark(float warmupTimeoutSec,
351             float runTimeoutSec, int maxIterations) throws IOException, BenchmarkException {
352         if (mModelHandle == 0) {
353             throw new UnsupportedModelException("Unsupported model");
354         }
355         CompilationBenchmarkResult result = runCompilationBenchmark(
356                 mModelHandle, maxIterations, warmupTimeoutSec, runTimeoutSec);
357         if (result == null) {
358             throw new BenchmarkException("Failed to run compilation benchmark");
359         }
360         return result;
361     }
362 
destroy()363     public void destroy() {
364         if (mModelHandle != 0) {
365             destroyModel(mModelHandle);
366             mModelHandle = 0;
367         }
368         if (mTemporaryModelFilePath != null) {
369             deleteOrWarn(mTemporaryModelFilePath);
370             mTemporaryModelFilePath = null;
371         }
372     }
373 
374     private final Random mRandom = new Random(System.currentTimeMillis());
375 
376     // We need to copy it to cache dir, so that TFlite can load it directly.
copyAssetToFile()377     private String copyAssetToFile() throws IOException {
378         @SuppressLint("DefaultLocale")
379         String outFileName =
380                 String.format("%s/%s-%d-%d.tflite", mContext.getCacheDir().getAbsolutePath(),
381                         mModelFile,
382                         Thread.currentThread().getId(), mRandom.nextInt(10000));
383 
384         copyAssetToFile(mContext, mModelFile + ".tflite", outFileName);
385         return outFileName;
386     }
387 
copyModelToFile(Context context, String modelFileName, File targetFile)388     public static boolean copyModelToFile(Context context, String modelFileName, File targetFile)
389             throws IOException {
390         if (!targetFile.exists() && !targetFile.createNewFile()) {
391             Log.w(TAG, String.format("Unable to create file %s", targetFile.getAbsolutePath()));
392             return false;
393         }
394         NNTestBase.copyAssetToFile(context, modelFileName, targetFile.getAbsolutePath());
395         return true;
396     }
397 
copyAssetToFile(Context context, String modelAssetName, String targetPath)398     public static void copyAssetToFile(Context context, String modelAssetName, String targetPath)
399             throws IOException {
400         AssetManager assetManager = context.getAssets();
401         try {
402             File outFile = new File(targetPath);
403 
404             try (InputStream in = assetManager.open(modelAssetName);
405                  FileOutputStream out = new FileOutputStream(outFile)) {
406                 byte[] byteBuffer = new byte[1024];
407                 int readBytes = -1;
408                 while ((readBytes = in.read(byteBuffer)) != -1) {
409                     out.write(byteBuffer, 0, readBytes);
410                 }
411             }
412         } catch (IOException e) {
413             Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
414             throw e;
415         }
416     }
417 
418     @Override
close()419     public void close()  {
420         destroy();
421     }
422 }
423