1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.util; 18 19 import android.app.Activity; 20 import android.os.Bundle; 21 import android.util.Log; 22 23 import com.android.nn.benchmark.core.NNTestBase; 24 import com.android.nn.benchmark.core.TestModels; 25 import com.android.nn.benchmark.core.TestModels.TestModelEntry; 26 27 import java.io.File; 28 29 30 /** 31 * Helper activity for dumping state of interference intermediate tensors. 32 * 33 * Example usage: 34 * adb shell am start -n com.android.nn.benchmark.app/com.android.nn.benchmark.\ 35 * util.DumpIntermediateTensors --es modelName mobilenet_v1_1.0_224_quant_topk_aosp,tts_float\ 36 * inputAssetIndex 0 37 * 38 * Assets will be then dumped into /data/data/com.android.nn.benchmark.app/files/intermediate 39 * To fetch: 40 * adb pull /data/data/com.android.nn.benchmark.app/files/intermediate 41 */ 42 public class DumpIntermediateTensors extends Activity { 43 protected static final String TAG = "VDEBUG"; 44 public static final String EXTRA_MODEL_NAME = "modelName"; 45 public static final String EXTRA_INPUT_ASSET_INDEX = "inputAssetIndex"; 46 public static final String EXTRA_INPUT_ASSET_SIZE = "inputAssetSize"; 47 public static final String DUMP_DIR = "intermediate"; 48 public static final String CPU_DIR = "cpu"; 49 public static final String NNAPI_DIR = "nnapi"; 50 // TODO(veralin): Update to use other models in vendor as well. 51 // Due to recent change in NNScoringTest, the model names are moved to here. 52 private static final String[] MODEL_NAMES = new String[]{ 53 "tts_float", 54 "asr_float", 55 "mobilenet_v1_1.0_224_quant_topk_aosp", 56 "mobilenet_v1_1.0_224_topk_aosp", 57 "mobilenet_v1_0.75_192_quant_topk_aosp", 58 "mobilenet_v1_0.75_192_topk_aosp", 59 "mobilenet_v1_0.5_160_quant_topk_aosp", 60 "mobilenet_v1_0.5_160_topk_aosp", 61 "mobilenet_v1_0.25_128_quant_topk_aosp", 62 "mobilenet_v1_0.25_128_topk_aosp", 63 "mobilenet_v2_0.35_128_topk_aosp", 64 "mobilenet_v2_0.5_160_topk_aosp", 65 "mobilenet_v2_0.75_192_topk_aosp", 66 "mobilenet_v2_1.0_224_topk_aosp", 67 "mobilenet_v2_1.0_224_quant_topk_aosp", 68 }; 69 70 @Override onCreate(Bundle savedInstanceState)71 protected void onCreate(Bundle savedInstanceState) { 72 super.onCreate(savedInstanceState); 73 Bundle extras = getIntent().getExtras(); 74 75 String userModelName = extras.getString(EXTRA_MODEL_NAME); 76 int inputAssetIndex = extras.getInt(EXTRA_INPUT_ASSET_INDEX, 0); 77 int inputAssetSize = extras.getInt(EXTRA_INPUT_ASSET_SIZE, 1); 78 79 // Default to run all models in NNScoringTest 80 String[] modelNames = userModelName == null ? MODEL_NAMES : userModelName.split(","); 81 82 try { 83 File dumpDir = new File(getFilesDir(), DUMP_DIR); 84 safeMkdir(dumpDir); 85 86 for (String modelName : modelNames) { 87 File modelDir = new File(getFilesDir() + "/" + DUMP_DIR, modelName); 88 safeMkdir(modelDir); 89 // Run in CPU and NNAPI mode 90 for (final boolean useNNAPI : new boolean[]{false, true}) { 91 String useNNAPIDir = useNNAPI ? NNAPI_DIR : CPU_DIR; 92 TestModelEntry modelEntry = TestModels.getModelByName(modelName); 93 NNTestBase testBase = modelEntry.createNNTestBase( 94 useNNAPI, /*enableIntermediateTensorsDump*/true, /*mmapModel*/false); 95 testBase.setupModel(this); 96 File outputDir = new File(getFilesDir() + "/" + DUMP_DIR + 97 "/" + modelName, useNNAPIDir); 98 safeMkdir(outputDir); 99 testBase.dumpAllLayers(outputDir, inputAssetIndex, inputAssetSize); 100 } 101 } 102 103 } catch (Exception e) { 104 Log.e(TAG, "Failed to dump tensors", e); 105 throw new IllegalStateException("Failed to dump tensors", e); 106 } 107 finish(); 108 } 109 deleteRecursive(File fileOrDirectory)110 private void deleteRecursive(File fileOrDirectory) { 111 if (fileOrDirectory.isDirectory()) { 112 for (File child : fileOrDirectory.listFiles()) { 113 deleteRecursive(child); 114 } 115 } 116 fileOrDirectory.delete(); 117 } 118 safeMkdir(File fileOrDirectory)119 private void safeMkdir(File fileOrDirectory) { 120 deleteRecursive(fileOrDirectory); 121 fileOrDirectory.mkdir(); 122 } 123 } 124