1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import android.content.res.AssetManager; 20 import android.util.Log; 21 22 import org.json.JSONArray; 23 import org.json.JSONException; 24 import org.json.JSONObject; 25 26 import java.io.IOException; 27 import java.io.InputStream; 28 import java.io.InputStreamReader; 29 import java.io.Reader; 30 31 /** Helper class to register test model definitions from assets data */ 32 public class TestModelsListLoader { 33 private static final String TAG = "NN_BENCHMARK"; 34 35 /** 36 * Parse list of models in form of json data. 37 * 38 * Example input: 39 * { "models" : [ 40 * {"name" : "modelName", 41 * "testName" : "testName", 42 * "baselineSec" : 0.03, 43 * "evaluator": "TopK", 44 * "inputSize" : [1,2,3,4], 45 * "dataSize" : 4, 46 * "inputOutputs" : [ {"input": "input1", "output": "output2"} ] 47 * } 48 * ]} 49 */ parseJSONModelsList(String jsonStringInput)50 static public void parseJSONModelsList(String jsonStringInput) throws JSONException { 51 JSONObject jsonRootObject = new JSONObject(jsonStringInput); 52 JSONArray jsonModelsArray = jsonRootObject.getJSONArray("models"); 53 54 for (int i = 0; i < jsonModelsArray.length(); i++) { 55 JSONObject jsonTestModelEntry = jsonModelsArray.getJSONObject(i); 56 57 String name = jsonTestModelEntry.getString("name"); 58 String testName = name; 59 if (jsonTestModelEntry.has("testName")) { 60 testName = jsonTestModelEntry.getString("testName"); 61 } 62 String modelFile = name; 63 if (jsonTestModelEntry.has("modelFile")) { 64 modelFile = jsonTestModelEntry.getString("modelFile"); 65 } 66 double baseline = jsonTestModelEntry.getDouble("baselineSec"); 67 int minSdkVersion = 0; 68 if (jsonTestModelEntry.has("minSdkVersion")) { 69 minSdkVersion = jsonTestModelEntry.getInt("minSdkVersion"); 70 } 71 EvaluatorConfig evaluator = null; 72 if (jsonTestModelEntry.has("evaluator")) { 73 JSONObject evaluatorJson = jsonTestModelEntry.getJSONObject("evaluator"); 74 evaluator = new EvaluatorConfig(evaluatorJson.getString("className"), 75 evaluatorJson.has("outputMeanStdDev") 76 ? evaluatorJson.getString("outputMeanStdDev") 77 : null, 78 evaluatorJson.has("expectedTop1") 79 ? evaluatorJson.getDouble("expectedTop1") 80 : null); 81 } 82 83 int dataSize = jsonTestModelEntry.getInt("dataSize"); 84 JSONArray jsonInputSize = jsonTestModelEntry.getJSONArray("inputSize"); 85 int[] inputSize = new int[jsonInputSize.length()]; 86 int inputSizeBytes = dataSize; 87 for (int k = 0; k < jsonInputSize.length(); ++k) { 88 inputSize[k] = jsonInputSize.getInt(k); 89 inputSizeBytes *= inputSize[k]; 90 } 91 92 InferenceInOutSequence.FromAssets[] inputOutputs = null; 93 if (jsonTestModelEntry.has("inputOutputs")) { 94 JSONArray jsonInputOutputs = jsonTestModelEntry.getJSONArray("inputOutputs"); 95 inputOutputs = 96 new InferenceInOutSequence.FromAssets[jsonInputOutputs.length()]; 97 98 for (int j = 0; j < jsonInputOutputs.length(); j++) { 99 JSONObject jsonInputOutput = jsonInputOutputs.getJSONObject(j); 100 String input = jsonInputOutput.getString("input"); 101 String[] outputs = null; 102 String output = jsonInputOutput.optString("output", null); 103 if (output != null) { 104 outputs = new String[]{output}; 105 } else { 106 JSONArray outputArray = jsonInputOutput.getJSONArray("outputs"); 107 if (outputArray != null) { 108 outputs = new String[outputArray.length()]; 109 for (int k = 0; k < outputArray.length(); ++k) { 110 outputs[k] = outputArray.getString(k); 111 } 112 } 113 } 114 115 inputOutputs[j] = new InferenceInOutSequence.FromAssets(input, outputs, 116 dataSize, 117 inputSizeBytes); 118 } 119 } 120 InferenceInOutSequence.FromDataset[] datasets = null; 121 if (jsonTestModelEntry.has("dataset")) { 122 JSONObject jsonDataset = jsonTestModelEntry.getJSONObject("dataset"); 123 String inputPath = jsonDataset.getString("inputPath"); 124 String groundTruth = jsonDataset.getString("groundTruth"); 125 String labels = jsonDataset.getString("labels"); 126 String preprocessor = jsonDataset.getString("preprocessor"); 127 if (inputSize.length != 4 || inputSize[0] != 1 || inputSize[1] != inputSize[2] || 128 inputSize[3] != 3) { 129 throw new IllegalArgumentException("Datasets only support square images," + 130 "input size [1, D, D, 3], given " + inputSize[0] + 131 ", " + inputSize[1] + ", " + inputSize[2] + ", " + inputSize[3]); 132 } 133 float quantScale = 0.f; 134 float quantZeroPoint = 0.f; 135 if (dataSize == 1) { 136 if (!jsonTestModelEntry.has("inputScale") || 137 !jsonTestModelEntry.has("inputZeroPoint")) { 138 throw new IllegalArgumentException("Quantized test model must include " + 139 "inputScale and inputZeroPoint for reading a dataset"); 140 } 141 quantScale = (float) jsonTestModelEntry.getDouble("inputScale"); 142 quantZeroPoint = (float) jsonTestModelEntry.getDouble("inputZeroPoint"); 143 } 144 datasets = new InferenceInOutSequence.FromDataset[]{ 145 new InferenceInOutSequence.FromDataset(inputPath, labels, groundTruth, 146 preprocessor, dataSize, quantScale, quantZeroPoint, inputSize[1]) 147 }; 148 } 149 150 TestModels.registerModel( 151 new TestModels.TestModelEntry(name, (float) baseline, inputSize, inputOutputs, 152 datasets, testName, modelFile, evaluator, minSdkVersion, dataSize)); 153 } 154 } 155 readAssetsFileAsString(InputStream inputStream)156 static String readAssetsFileAsString(InputStream inputStream) throws IOException { 157 Reader reader = new InputStreamReader(inputStream); 158 StringBuilder sb = new StringBuilder(); 159 char buffer[] = new char[16384]; 160 int len; 161 while ((len = reader.read(buffer)) > 0) { 162 sb.append(buffer, 0, len); 163 } 164 reader.close(); 165 return sb.toString(); 166 } 167 168 /** Parse all ".json" files in root assets directory */ 169 private static final String MODELS_LIST_ROOT = "models_list"; 170 parseFromAssets(AssetManager assetManager)171 static public void parseFromAssets(AssetManager assetManager) throws IOException { 172 for (String file : assetManager.list(MODELS_LIST_ROOT)) { 173 if (!file.endsWith(".json")) { 174 continue; 175 } 176 try { 177 parseJSONModelsList(readAssetsFileAsString( 178 assetManager.open(MODELS_LIST_ROOT + "/" + file))); 179 } catch (JSONException e) { 180 Log.e(TAG, "error reading json model list", e); 181 throw new IOException("JSON error in " + file, e); 182 } catch (Exception e) { 183 Log.e(TAG, "error parsing json model list", e); 184 // Wrap exception to add a filename to it 185 throw new IOException("Error while parsing " + file, e); 186 } 187 188 } 189 } 190 } 191