/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.nn.benchmark.core;

import android.annotation.SuppressLint;
import android.content.Context;
import android.content.res.AssetManager;
import android.os.Build;
import android.util.Log;
import android.util.Pair;
import android.widget.TextView;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;

/**
 * Base class for NNAPI benchmark tests: loads a TFLite model from the APK
 * assets, runs it through the native benchmark JNI library, and collects
 * inference / compilation results.
 */
public class NNTestBase implements AutoCloseable {
    protected static final String TAG = "NN_TESTBASE";

    // Used to load the 'native-lib' library on application startup.
    static {
        System.loadLibrary("nnbenchmark_jni");
    }

    // Does the device have any NNAPI accelerator?
    // We only consider a real device, not 'nnapi-reference'.
    public static native boolean hasAccelerator();

    /**
     * Fills resultList with the name of the available NNAPI accelerators
     *
     * @return False if any error occurred, true otherwise
     */
    private static native boolean getAcceleratorNames(List<String> resultList);

    private synchronized native long initModel(
            String modelFileName,
            boolean useNNApi,
            boolean enableIntermediateTensorsDump,
            String nnApiDeviceName,
            boolean mmapModel,
            String nnApiCacheDir) throws NnApiDelegationFailure;

    private synchronized native void destroyModel(long modelHandle);

    private synchronized native boolean resizeInputTensors(long modelHandle, int[] inputShape);

    private synchronized native boolean runBenchmark(long modelHandle,
            List<InferenceInOutSequence> inOutList,
            List<InferenceResult> resultList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags);

    private synchronized native CompilationBenchmarkResult runCompilationBenchmark(
            long modelHandle, int maxNumIterations, float warmupTimeoutSec, float runTimeoutSec);

    private synchronized native void dumpAllLayers(
            long modelHandle,
            String dumpPath,
            List<InferenceInOutSequence> inOutList);

    /**
     * Returns the names of the available NNAPI accelerators on this device,
     * excluding the software 'nnapi-reference' implementation.
     *
     * @return the accelerator names, or an empty list if they could not be
     *         retrieved (the error is logged).
     */
    public static List<String> availableAcceleratorNames() {
        List<String> availableAccelerators = new ArrayList<>();
        if (NNTestBase.getAcceleratorNames(availableAccelerators)) {
            return availableAccelerators.stream().filter(
                    acceleratorName -> !acceleratorName.equalsIgnoreCase(
                            "nnapi-reference")).collect(Collectors.toList());
        } else {
            Log.e(TAG, "Unable to retrieve accelerator names!!");
            // Type-safe empty list instead of the raw Collections.EMPTY_LIST.
            return Collections.emptyList();
        }
    }

    /** Discard inference output in inference results. */
    public static final int FLAG_DISCARD_INFERENCE_OUTPUT = 1 << 0;
    /**
     * Do not expect golden outputs with inference inputs.
     *
     * Useful in cases where there's no straightforward golden output values
     * for the benchmark. This will also skip calculating basic (golden
     * output based) error metrics.
     */
    public static final int FLAG_IGNORE_GOLDEN_OUTPUT = 1 << 1;


    protected Context mContext;
    protected TextView mText;
    private final String mModelName;
    private final String mModelFile;
    // Handle to the native model instance; 0 when no model is loaded.
    private long mModelHandle;
    private final int[] mInputShape;
    private final InferenceInOutSequence.FromAssets[] mInputOutputAssets;
    private final InferenceInOutSequence.FromDataset[] mInputOutputDatasets;
    private final EvaluatorConfig mEvaluatorConfig;
    private EvaluatorInterface mEvaluator;
    private boolean mHasGoldenOutputs;
    private boolean mUseNNApi = false;
    private boolean mEnableIntermediateTensorsDump = false;
    private final int mMinSdkVersion;
    private Optional<String> mNNApiDeviceName = Optional.empty();
    private boolean mMmapModel = false;
    // Path where the current model has been stored for execution
    private String mTemporaryModelFilePath;

    /**
     * @param inputOutputAssets   inference inputs/goldens read from APK assets;
     *                            exactly one of this and inputOutputDatasets must be non-null.
     * @param inputOutputDatasets inference inputs/goldens read from a dataset.
     * @param minSdkVersion       minimum SDK level required to run this model, or 0 for no check.
     * @throws IllegalArgumentException if neither or both input sources are given.
     */
    public NNTestBase(String modelName, String modelFile, int[] inputShape,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets,
            EvaluatorConfig evaluator, int minSdkVersion) {
        if (inputOutputAssets == null && inputOutputDatasets == null) {
            throw new IllegalArgumentException(
                    "Neither inputOutputAssets nor inputOutputDatasets given - no inputs");
        }
        if (inputOutputAssets != null && inputOutputDatasets != null) {
            throw new IllegalArgumentException(
                    "Both inputOutputAssets and inputOutputDatasets given. Only one "
                            + "supported at once.");
        }
        mModelName = modelName;
        mModelFile = modelFile;
        mInputShape = inputShape;
        mInputOutputAssets = inputOutputAssets;
        mInputOutputDatasets = inputOutputDatasets;
        mModelHandle = 0;
        mEvaluatorConfig = evaluator;
        mMinSdkVersion = minSdkVersion;
    }

    public void useNNApi() {
        useNNApi(true);
    }

    public void useNNApi(boolean value) {
        mUseNNApi = value;
    }

    public void enableIntermediateTensorsDump() {
        enableIntermediateTensorsDump(true);
    }

    public void enableIntermediateTensorsDump(boolean value) {
        mEnableIntermediateTensorsDump = value;
    }

    public void setNNApiDeviceName(String value) {
        if (!mUseNNApi) {
            Log.e(TAG, "Setting device name has no effect when not using NNAPI");
        }
        mNNApiDeviceName = Optional.ofNullable(value);
    }

    public void setMmapModel(boolean value) {
        mMmapModel = value;
    }

    /**
     * Copies the model out of the APK assets and initializes it through the
     * native library. Must be called before any of the run/dump methods.
     *
     * @return false if the native model could not be initialized.
     * @throws IOException if the model asset cannot be copied to local storage.
     * @throws NnApiDelegationFailure if NNAPI delegation fails in initModel.
     */
    public final boolean setupModel(Context ipcxt) throws IOException, NnApiDelegationFailure {
        mContext = ipcxt;
        // Drop any model file left behind by a previous setup of this instance.
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
        }
        mTemporaryModelFilePath = copyAssetToFile();
        String nnApiCacheDir = mContext.getCodeCacheDir().toString();
        mModelHandle = initModel(
                mTemporaryModelFilePath, mUseNNApi, mEnableIntermediateTensorsDump,
                mNNApiDeviceName.orElse(null), mMmapModel, nnApiCacheDir);
        if (mModelHandle == 0) {
            Log.e(TAG, "Failed to init the model");
            return false;
        }
        // Previously the result was silently discarded; log so that a bad
        // input shape is visible during benchmark bring-up.
        if (!resizeInputTensors(mModelHandle, mInputShape)) {
            Log.w(TAG, "Failed to resize input tensors");
        }

        if (mEvaluatorConfig != null) {
            mEvaluator = mEvaluatorConfig.createEvaluator(mContext.getAssets());
        }
        return true;
    }

    public String getTestInfo() {
        return mModelName;
    }

    public EvaluatorInterface getEvaluator() {
        return mEvaluator;
    }

    /**
     * @throws UnsupportedSdkException if the device SDK level is below the
     *         model's declared minimum.
     */
    public void checkSdkVersion() throws UnsupportedSdkException {
        if (mMinSdkVersion > 0 && Build.VERSION.SDK_INT < mMinSdkVersion) {
            throw new UnsupportedSdkException("SDK version not supported. Minimum required: " +
                    mMinSdkVersion + ", current version: " + Build.VERSION.SDK_INT);
        }
    }

    private void deleteOrWarn(String path) {
        if (!new File(path).delete()) {
            Log.w(TAG, String.format(
                    "Unable to delete file '%s'. This might cause device to run out of space.",
                    path));
        }
    }


    private List<InferenceInOutSequence> getInputOutputAssets() throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList =
                getInputOutputAssets(mContext, mInputOutputAssets, mInputOutputDatasets);

        // All sequences must agree on whether golden outputs are available;
        // a mix would make the error metrics meaningless.
        Boolean lastGolden = null;
        for (InferenceInOutSequence sequence : inOutList) {
            mHasGoldenOutputs = sequence.hasGoldenOutput();
            if (lastGolden == null) {
                lastGolden = mHasGoldenOutputs;
            } else {
                if (lastGolden != mHasGoldenOutputs) {
                    throw new IllegalArgumentException(
                            "Some inputs for " + mModelName + " have outputs while some don't.");
                }
            }
        }
        return inOutList;
    }

    public static List<InferenceInOutSequence> getInputOutputAssets(Context context,
            InferenceInOutSequence.FromAssets[] inputOutputAssets,
            InferenceInOutSequence.FromDataset[] inputOutputDatasets) throws IOException {
        // TODO: Caching, don't read inputs for every inference
        List<InferenceInOutSequence> inOutList = new ArrayList<>();
        if (inputOutputAssets != null) {
            for (InferenceInOutSequence.FromAssets ioAsset : inputOutputAssets) {
                inOutList.add(ioAsset.readAssets(context.getAssets()));
            }
        }
        if (inputOutputDatasets != null) {
            for (InferenceInOutSequence.FromDataset dataset : inputOutputDatasets) {
                inOutList.addAll(dataset.readDataset(context.getAssets(), context.getCacheDir()));
            }
        }

        return inOutList;
    }

    public int getDefaultFlags() {
        int flags = 0;
        if (!mHasGoldenOutputs) {
            flags = flags | FLAG_IGNORE_GOLDEN_OUTPUT;
        }
        if (mEvaluator == null) {
            flags = flags | FLAG_DISCARD_INFERENCE_OUTPUT;
        }
        return flags;
    }

    public void dumpAllLayers(File dumpDir, int inputAssetIndex, int inputAssetSize)
            throws IOException {
        if (!dumpDir.exists() || !dumpDir.isDirectory()) {
            throw new IllegalArgumentException("dumpDir doesn't exist or is not a directory");
        }
        if (!mEnableIntermediateTensorsDump) {
            throw new IllegalStateException("mEnableIntermediateTensorsDump is " +
                    "set to false, impossible to proceed");
        }

        List<InferenceInOutSequence> ios = getInputOutputAssets();
        dumpAllLayers(mModelHandle, dumpDir.toString(),
                ios.subList(inputAssetIndex, inputAssetSize));
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runInferenceOnce()
            throws IOException, BenchmarkException {
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int flags = getDefaultFlags();
        Pair<List<InferenceInOutSequence>, List<InferenceResult>> output =
                runBenchmark(ios, 1, Float.MAX_VALUE, flags);
        return output;
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(float timeoutSec)
            throws IOException, BenchmarkException {
        // Run as many as possible before timeout.
        int flags = getDefaultFlags();
        return runBenchmark(getInputOutputAssets(), 0xFFFFFFF, timeoutSec, flags);
    }

    /** Run through whole input set (once or multiple times). */
    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmarkCompleteInputSet(
            int setRepeat,
            float timeoutSec)
            throws IOException, BenchmarkException {
        int flags = getDefaultFlags();
        List<InferenceInOutSequence> ios = getInputOutputAssets();
        int totalSequenceInferencesCount = ios.size() * setRepeat;
        int expectedResults = 0;
        for (InferenceInOutSequence iosSeq : ios) {
            expectedResults += iosSeq.size();
        }
        expectedResults *= setRepeat;

        Pair<List<InferenceInOutSequence>, List<InferenceResult>> result =
                runBenchmark(ios, totalSequenceInferencesCount, timeoutSec,
                        flags);
        if (result.second.size() != expectedResults) {
            // We reached a timeout or failed to evaluate whole set for other reason, abort.
            final String errorMsg = "Failed to evaluate complete input set, expected: "
                    + expectedResults +
                    ", received: " + result.second.size();
            Log.w(TAG, errorMsg);
            throw new IllegalStateException(errorMsg);
        }
        return result;
    }

    public Pair<List<InferenceInOutSequence>, List<InferenceResult>> runBenchmark(
            List<InferenceInOutSequence> inOutList,
            int inferencesSeqMaxCount,
            float timeoutSec,
            int flags)
            throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        List<InferenceResult> resultList = new ArrayList<>();
        if (!runBenchmark(mModelHandle, inOutList, resultList, inferencesSeqMaxCount,
                timeoutSec, flags)) {
            throw new BenchmarkException("Failed to run benchmark");
        }
        return new Pair<>(inOutList, resultList);
    }

    public CompilationBenchmarkResult runCompilationBenchmark(float warmupTimeoutSec,
            float runTimeoutSec, int maxIterations) throws IOException, BenchmarkException {
        if (mModelHandle == 0) {
            throw new UnsupportedModelException("Unsupported model");
        }
        CompilationBenchmarkResult result = runCompilationBenchmark(
                mModelHandle, maxIterations, warmupTimeoutSec, runTimeoutSec);
        if (result == null) {
            throw new BenchmarkException("Failed to run compilation benchmark");
        }
        return result;
    }

    /** Releases the native model and deletes the temporary model file. */
    public void destroy() {
        if (mModelHandle != 0) {
            destroyModel(mModelHandle);
            mModelHandle = 0;
        }
        if (mTemporaryModelFilePath != null) {
            deleteOrWarn(mTemporaryModelFilePath);
            mTemporaryModelFilePath = null;
        }
    }

    private final Random mRandom = new Random(System.currentTimeMillis());

    // We need to copy it to cache dir, so that TFlite can load it directly.
    private String copyAssetToFile() throws IOException {
        @SuppressLint("DefaultLocale")
        // Thread id + random suffix keeps parallel test instances from
        // clobbering each other's model files.
        String outFileName =
                String.format("%s/%s-%d-%d.tflite", mContext.getCacheDir().getAbsolutePath(),
                        mModelFile,
                        Thread.currentThread().getId(), mRandom.nextInt(10000));

        copyAssetToFile(mContext, mModelFile + ".tflite", outFileName);
        return outFileName;
    }

    public static boolean copyModelToFile(Context context, String modelFileName, File targetFile)
            throws IOException {
        if (!targetFile.exists() && !targetFile.createNewFile()) {
            Log.w(TAG, String.format("Unable to create file %s", targetFile.getAbsolutePath()));
            return false;
        }
        NNTestBase.copyAssetToFile(context, modelFileName, targetFile.getAbsolutePath());
        return true;
    }

    public static void copyAssetToFile(Context context, String modelAssetName, String targetPath)
            throws IOException {
        AssetManager assetManager = context.getAssets();
        try {
            File outFile = new File(targetPath);

            try (InputStream in = assetManager.open(modelAssetName);
                 FileOutputStream out = new FileOutputStream(outFile)) {
                byte[] byteBuffer = new byte[1024];
                int readBytes = -1;
                while ((readBytes = in.read(byteBuffer)) != -1) {
                    out.write(byteBuffer, 0, readBytes);
                }
            }
        } catch (IOException e) {
            Log.e(TAG, "Failed to copy asset file: " + modelAssetName, e);
            throw e;
        }
    }

    @Override
    public void close() {
        destroy();
    }
}