1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import android.os.Bundle; 20 import android.os.Parcel; 21 import android.os.Parcelable; 22 import android.text.TextUtils; 23 import android.util.Pair; 24 25 import java.util.ArrayList; 26 import java.util.Arrays; 27 import java.util.List; 28 29 public class BenchmarkResult implements Parcelable { 30 public final static String BACKEND_TFLITE_NNAPI = "TFLite_NNAPI"; 31 public final static String BACKEND_TFLITE_CPU = "TFLite_CPU"; 32 33 private final static int TIME_FREQ_ARRAY_SIZE = 32; 34 35 /** The name of the benchmark */ 36 private String mTestInfo; 37 38 /** Latency results */ 39 private LatencyResult mLatencyInference; 40 private LatencyResult mLatencyCompileWithoutCache; 41 private LatencyResult mLatencySaveToCache; 42 private LatencyResult mLatencyPrepareFromCache; 43 44 /** Accuracy results */ 45 private float mSumOfMSEs; 46 private float mMaxSingleError; 47 private int mNumberOfEvaluatorResults; 48 private String[] mEvaluatorKeys = {}; 49 private float[] mEvaluatorResults = {}; 50 51 /** Type of backend used for inference */ 52 private String mBackendType; 53 54 /** Size of test set using for inference */ 55 private int mTestSetSize; 56 57 /** Size of compilation cache files in bytes */ 58 private int mCompilationCacheSizeBytes = 0; 59 60 /** List of validation errors */ 61 private String[] mValidationErrors = {}; 62 63 /** Error that prevents the benchmark from running, e.g. SDK version not supported. */ 64 private String mBenchmarkError; 65 BenchmarkResult(LatencyResult inferenceLatency, float sumOfMSEs, float maxSingleError, String testInfo, String[] evaluatorKeys, float[] evaluatorResults, String backendType, int testSetSize, String[] validationErrors)66 public BenchmarkResult(LatencyResult inferenceLatency, 67 float sumOfMSEs, float maxSingleError, String testInfo, 68 String[] evaluatorKeys, float[] evaluatorResults, 69 String backendType, int testSetSize, String[] validationErrors) { 70 mLatencyInference = inferenceLatency; 71 mSumOfMSEs = sumOfMSEs; 72 mMaxSingleError = maxSingleError; 73 mTestInfo = testInfo; 74 mBackendType = backendType; 75 mTestSetSize = testSetSize; 76 if (validationErrors == null) { 77 mValidationErrors = new String[0]; 78 } else { 79 mValidationErrors = validationErrors; 80 } 81 82 if (evaluatorKeys == null) { 83 mEvaluatorKeys = new String[0]; 84 } else { 85 mEvaluatorKeys = evaluatorKeys; 86 } 87 if (evaluatorResults == null) { 88 mEvaluatorResults = new float[0]; 89 } else { 90 mEvaluatorResults = evaluatorResults; 91 } 92 if (mEvaluatorResults.length != mEvaluatorKeys.length) { 93 throw new IllegalArgumentException("Different number of evaluator keys vs values"); 94 } 95 mNumberOfEvaluatorResults = mEvaluatorResults.length; 96 } 97 BenchmarkResult(String benchmarkError)98 public BenchmarkResult(String benchmarkError) { 99 mBenchmarkError = benchmarkError; 100 } 101 hasValidationErrors()102 public boolean hasValidationErrors() { 103 return mValidationErrors.length > 0; 104 } 105 BenchmarkResult(Parcel in)106 protected BenchmarkResult(Parcel in) { 107 mLatencyInference = in.readParcelable(LatencyResult.class.getClassLoader()); 108 mLatencyCompileWithoutCache = in.readParcelable(LatencyResult.class.getClassLoader()); 109 mLatencySaveToCache = in.readParcelable(LatencyResult.class.getClassLoader()); 110 mLatencyPrepareFromCache = in.readParcelable(LatencyResult.class.getClassLoader()); 111 mSumOfMSEs = in.readFloat(); 112 mMaxSingleError = in.readFloat(); 113 mTestInfo = in.readString(); 114 mNumberOfEvaluatorResults = in.readInt(); 115 mEvaluatorKeys = new String[mNumberOfEvaluatorResults]; 116 in.readStringArray(mEvaluatorKeys); 117 mEvaluatorResults = new float[mNumberOfEvaluatorResults]; 118 in.readFloatArray(mEvaluatorResults); 119 if (mEvaluatorResults.length != mEvaluatorKeys.length) { 120 throw new IllegalArgumentException("Different number of evaluator keys vs values"); 121 } 122 mBackendType = in.readString(); 123 mTestSetSize = in.readInt(); 124 mCompilationCacheSizeBytes = in.readInt(); 125 int validationsErrorsSize = in.readInt(); 126 mValidationErrors = new String[validationsErrorsSize]; 127 in.readStringArray(mValidationErrors); 128 mBenchmarkError = in.readString(); 129 } 130 131 @Override describeContents()132 public int describeContents() { 133 return 0; 134 } 135 136 @Override writeToParcel(Parcel dest, int flags)137 public void writeToParcel(Parcel dest, int flags) { 138 dest.writeParcelable(mLatencyInference, flags); 139 dest.writeParcelable(mLatencyCompileWithoutCache, flags); 140 dest.writeParcelable(mLatencySaveToCache, flags); 141 dest.writeParcelable(mLatencyPrepareFromCache, flags); 142 dest.writeFloat(mSumOfMSEs); 143 dest.writeFloat(mMaxSingleError); 144 dest.writeString(mTestInfo); 145 dest.writeInt(mNumberOfEvaluatorResults); 146 dest.writeStringArray(mEvaluatorKeys); 147 dest.writeFloatArray(mEvaluatorResults); 148 dest.writeString(mBackendType); 149 dest.writeInt(mTestSetSize); 150 dest.writeInt(mCompilationCacheSizeBytes); 151 dest.writeInt(mValidationErrors.length); 152 dest.writeStringArray(mValidationErrors); 153 dest.writeString(mBenchmarkError); 154 } 155 156 @SuppressWarnings("unused") 157 public static final Parcelable.Creator<BenchmarkResult> CREATOR = 158 new Parcelable.Creator<BenchmarkResult>() { 159 @Override 160 public BenchmarkResult createFromParcel(Parcel in) { 161 return new BenchmarkResult(in); 162 } 163 164 @Override 165 public BenchmarkResult[] newArray(int size) { 166 return new BenchmarkResult[size]; 167 } 168 }; 169 getError()170 public float getError() { 171 return mSumOfMSEs; 172 } 173 getMeanTimeSec()174 public float getMeanTimeSec() { 175 return mLatencyInference.getMeanTimeSec(); 176 } 177 getEvaluatorResults()178 public List<Pair<String, Float>> getEvaluatorResults() { 179 List<Pair<String, Float>> results = new ArrayList<>(); 180 for (int i = 0; i < mEvaluatorKeys.length; ++i) { 181 results.add(new Pair<>(mEvaluatorKeys[i], mEvaluatorResults[i])); 182 } 183 return results; 184 } 185 186 @Override toString()187 public String toString() { 188 if (!TextUtils.isEmpty(mBenchmarkError)) { 189 return mBenchmarkError; 190 } 191 192 StringBuilder result = new StringBuilder("BenchmarkResult{" + 193 "mTestInfo='" + mTestInfo + '\'' + 194 ", mLatencyInference=" + mLatencyInference.toString() + 195 ", mSumOfMSEs=" + mSumOfMSEs + 196 ", mMaxSingleErrors=" + mMaxSingleError + 197 ", mBackendType=" + mBackendType + 198 ", mTestSetSize=" + mTestSetSize); 199 for (int i = 0; i < mEvaluatorKeys.length; i++) { 200 result.append(", ").append(mEvaluatorKeys[i]).append("=").append(mEvaluatorResults[i]); 201 } 202 203 result.append(", mValidationErrors=["); 204 for (int i = 0; i < mValidationErrors.length; i++) { 205 result.append(mValidationErrors[i]); 206 if (i < mValidationErrors.length - 1) { 207 result.append(","); 208 } 209 } 210 result.append("]"); 211 212 if (mLatencyCompileWithoutCache != null) { 213 result.append(", mLatencyCompileWithoutCache=") 214 .append(mLatencyCompileWithoutCache.toString()); 215 } 216 if (mLatencySaveToCache != null) { 217 result.append(", mLatencySaveToCache=").append(mLatencySaveToCache.toString()); 218 } 219 if (mLatencyPrepareFromCache != null) { 220 result.append(", mLatencyPrepareFromCache=") 221 .append(mLatencyPrepareFromCache.toString()); 222 } 223 result.append(", mCompilationCacheSizeBytes=").append(mCompilationCacheSizeBytes); 224 225 result.append('}'); 226 return result.toString(); 227 } 228 hasBenchmarkError()229 public boolean hasBenchmarkError() { 230 return !TextUtils.isEmpty(mBenchmarkError); 231 } 232 getBenchmarkError()233 public String getBenchmarkError() { 234 if (!hasBenchmarkError()) return null; 235 236 return mBenchmarkError; 237 } 238 setBenchmarkError(String benchmarkError)239 public void setBenchmarkError(String benchmarkError) { 240 mBenchmarkError = benchmarkError; 241 } 242 getSummary(float baselineSec)243 public String getSummary(float baselineSec) { 244 if (hasBenchmarkError()) { 245 return getBenchmarkError(); 246 } 247 return mLatencyInference.getSummary(baselineSec); 248 } 249 toBundle(String testName)250 public Bundle toBundle(String testName) { 251 Bundle results = new Bundle(); 252 if (!TextUtils.isEmpty(mBenchmarkError)) { 253 results.putString(testName + "_error", mBenchmarkError); 254 return results; 255 } 256 257 mLatencyInference.putToBundle(results, testName + "_inference"); 258 results.putFloat(testName + "_inference_mean_square_error", 259 mSumOfMSEs / mLatencyInference.getIterations()); 260 results.putFloat(testName + "_inference_max_single_error", mMaxSingleError); 261 for (int i = 0; i < mEvaluatorKeys.length; i++) { 262 results.putFloat(testName + "_inference_" + mEvaluatorKeys[i], mEvaluatorResults[i]); 263 } 264 if (mLatencyCompileWithoutCache != null) { 265 mLatencyCompileWithoutCache.putToBundle(results, testName + "_compile_without_cache"); 266 } 267 if (mLatencySaveToCache != null) { 268 mLatencySaveToCache.putToBundle(results, testName + "_save_to_cache"); 269 } 270 if (mLatencyPrepareFromCache != null) { 271 mLatencyPrepareFromCache.putToBundle(results, testName + "_prepare_from_cache"); 272 } 273 if (mCompilationCacheSizeBytes > 0) { 274 results.putInt(testName + "_compilation_cache_size", mCompilationCacheSizeBytes); 275 } 276 return results; 277 } 278 279 @SuppressWarnings("AndroidJdkLibsChecker") toCsvLine()280 public String toCsvLine() { 281 if (!TextUtils.isEmpty(mBenchmarkError)) { 282 return ""; 283 } 284 285 StringBuilder sb = new StringBuilder(); 286 sb.append(mTestInfo).append(',').append(mBackendType); 287 288 mLatencyInference.appendToCsvLine(sb); 289 290 sb.append(',').append(String.join(",", 291 String.valueOf(mMaxSingleError), 292 String.valueOf(mTestSetSize), 293 String.valueOf(mEvaluatorKeys.length), 294 String.valueOf(mValidationErrors.length))); 295 296 for (int i = 0; i < mEvaluatorKeys.length; ++i) { 297 sb.append(',').append(mEvaluatorKeys[i]); 298 } 299 300 for (int i = 0; i < mEvaluatorKeys.length; ++i) { 301 sb.append(',').append(mEvaluatorResults[i]); 302 } 303 304 for (String validationError : mValidationErrors) { 305 sb.append(',').append(validationError.replace(',', ' ')); 306 } 307 308 sb.append(',').append(mLatencyCompileWithoutCache != null); 309 if (mLatencyCompileWithoutCache != null) { 310 mLatencyCompileWithoutCache.appendToCsvLine(sb); 311 } 312 sb.append(',').append(mLatencySaveToCache != null); 313 if (mLatencySaveToCache != null) { 314 mLatencySaveToCache.appendToCsvLine(sb); 315 } 316 sb.append(',').append(mLatencyPrepareFromCache != null); 317 if (mLatencyPrepareFromCache != null) { 318 mLatencyPrepareFromCache.appendToCsvLine(sb); 319 } 320 sb.append(',').append(mCompilationCacheSizeBytes); 321 322 sb.append('\n'); 323 return sb.toString(); 324 } 325 fromInferenceResults( String testInfo, String backendType, List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults, EvaluatorInterface evaluator)326 public static BenchmarkResult fromInferenceResults( 327 String testInfo, 328 String backendType, 329 List<InferenceInOutSequence> inferenceInOuts, 330 List<InferenceResult> inferenceResults, 331 EvaluatorInterface evaluator) { 332 float[] latencies = new float[inferenceResults.size()]; 333 float sumOfMSEs = 0; 334 float maxSingleError = 0; 335 for (int i = 0; i < inferenceResults.size(); i++) { 336 InferenceResult iresult = inferenceResults.get(i); 337 latencies[i] = iresult.mComputeTimeSec; 338 if (iresult.mMeanSquaredErrors != null) { 339 for (float mse : iresult.mMeanSquaredErrors) { 340 sumOfMSEs += mse; 341 } 342 } 343 if (iresult.mMaxSingleErrors != null) { 344 for (float mse : iresult.mMaxSingleErrors) { 345 if (mse > maxSingleError) { 346 maxSingleError = mse; 347 } 348 } 349 } 350 } 351 352 String[] evaluatorKeys = null; 353 float[] evaluatorResults = null; 354 String[] validationErrors = null; 355 if (evaluator != null) { 356 ArrayList<String> keys = new ArrayList<String>(); 357 ArrayList<Float> results = new ArrayList<Float>(); 358 ArrayList<String> validationErrorsList = new ArrayList<>(); 359 evaluator.EvaluateAccuracy(inferenceInOuts, inferenceResults, keys, results, 360 validationErrorsList); 361 evaluatorKeys = new String[keys.size()]; 362 evaluatorKeys = keys.toArray(evaluatorKeys); 363 evaluatorResults = new float[results.size()]; 364 for (int i = 0; i < evaluatorResults.length; i++) { 365 evaluatorResults[i] = results.get(i).floatValue(); 366 } 367 validationErrors = new String[validationErrorsList.size()]; 368 validationErrorsList.toArray(validationErrors); 369 } 370 371 // Calc test set size 372 int testSetSize = 0; 373 for (InferenceInOutSequence iios : inferenceInOuts) { 374 testSetSize += iios.size(); 375 } 376 377 return new BenchmarkResult(new LatencyResult(latencies), sumOfMSEs, maxSingleError, 378 testInfo, evaluatorKeys, evaluatorResults, backendType, testSetSize, 379 validationErrors); 380 } 381 setCompilationBenchmarkResult(CompilationBenchmarkResult result)382 public void setCompilationBenchmarkResult(CompilationBenchmarkResult result) { 383 mLatencyCompileWithoutCache = new LatencyResult(result.mCompileWithoutCacheTimeSec); 384 if (result.mSaveToCacheTimeSec != null) { 385 mLatencySaveToCache = new LatencyResult(result.mSaveToCacheTimeSec); 386 } 387 if (result.mPrepareFromCacheTimeSec != null) { 388 mLatencyPrepareFromCache = new LatencyResult(result.mPrepareFromCacheTimeSec); 389 } 390 mCompilationCacheSizeBytes = result.mCacheSizeBytes; 391 } 392 } 393