1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.nn.benchmark.core;
18 
19 import android.os.Bundle;
20 import android.os.Parcel;
21 import android.os.Parcelable;
22 import android.text.TextUtils;
23 import android.util.Pair;
24 
25 import java.util.ArrayList;
26 import java.util.Arrays;
27 import java.util.List;
28 
29 public class BenchmarkResult implements Parcelable {
30     public final static String BACKEND_TFLITE_NNAPI = "TFLite_NNAPI";
31     public final static String BACKEND_TFLITE_CPU = "TFLite_CPU";
32 
33     private final static int TIME_FREQ_ARRAY_SIZE = 32;
34 
35     /** The name of the benchmark */
36     private String mTestInfo;
37 
38     /** Latency results */
39     private LatencyResult mLatencyInference;
40     private LatencyResult mLatencyCompileWithoutCache;
41     private LatencyResult mLatencySaveToCache;
42     private LatencyResult mLatencyPrepareFromCache;
43 
44     /** Accuracy results */
45     private float mSumOfMSEs;
46     private float mMaxSingleError;
47     private int mNumberOfEvaluatorResults;
48     private String[] mEvaluatorKeys = {};
49     private float[] mEvaluatorResults = {};
50 
51     /** Type of backend used for inference */
52     private String mBackendType;
53 
54     /** Size of test set using for inference */
55     private int mTestSetSize;
56 
57     /** Size of compilation cache files in bytes */
58     private int mCompilationCacheSizeBytes = 0;
59 
60     /** List of validation errors */
61     private String[] mValidationErrors = {};
62 
63     /** Error that prevents the benchmark from running, e.g. SDK version not supported. */
64     private String mBenchmarkError;
65 
BenchmarkResult(LatencyResult inferenceLatency, float sumOfMSEs, float maxSingleError, String testInfo, String[] evaluatorKeys, float[] evaluatorResults, String backendType, int testSetSize, String[] validationErrors)66     public BenchmarkResult(LatencyResult inferenceLatency,
67             float sumOfMSEs, float maxSingleError, String testInfo,
68             String[] evaluatorKeys, float[] evaluatorResults,
69             String backendType, int testSetSize, String[] validationErrors) {
70         mLatencyInference = inferenceLatency;
71         mSumOfMSEs = sumOfMSEs;
72         mMaxSingleError = maxSingleError;
73         mTestInfo = testInfo;
74         mBackendType = backendType;
75         mTestSetSize = testSetSize;
76         if (validationErrors == null) {
77             mValidationErrors = new String[0];
78         } else {
79             mValidationErrors = validationErrors;
80         }
81 
82         if (evaluatorKeys == null) {
83             mEvaluatorKeys = new String[0];
84         } else {
85             mEvaluatorKeys = evaluatorKeys;
86         }
87         if (evaluatorResults == null) {
88             mEvaluatorResults = new float[0];
89         } else {
90             mEvaluatorResults = evaluatorResults;
91         }
92         if (mEvaluatorResults.length != mEvaluatorKeys.length) {
93             throw new IllegalArgumentException("Different number of evaluator keys vs values");
94         }
95         mNumberOfEvaluatorResults = mEvaluatorResults.length;
96     }
97 
BenchmarkResult(String benchmarkError)98     public BenchmarkResult(String benchmarkError) {
99         mBenchmarkError = benchmarkError;
100     }
101 
hasValidationErrors()102     public boolean hasValidationErrors() {
103         return mValidationErrors.length > 0;
104     }
105 
BenchmarkResult(Parcel in)106     protected BenchmarkResult(Parcel in) {
107         mLatencyInference = in.readParcelable(LatencyResult.class.getClassLoader());
108         mLatencyCompileWithoutCache = in.readParcelable(LatencyResult.class.getClassLoader());
109         mLatencySaveToCache = in.readParcelable(LatencyResult.class.getClassLoader());
110         mLatencyPrepareFromCache = in.readParcelable(LatencyResult.class.getClassLoader());
111         mSumOfMSEs = in.readFloat();
112         mMaxSingleError = in.readFloat();
113         mTestInfo = in.readString();
114         mNumberOfEvaluatorResults = in.readInt();
115         mEvaluatorKeys = new String[mNumberOfEvaluatorResults];
116         in.readStringArray(mEvaluatorKeys);
117         mEvaluatorResults = new float[mNumberOfEvaluatorResults];
118         in.readFloatArray(mEvaluatorResults);
119         if (mEvaluatorResults.length != mEvaluatorKeys.length) {
120             throw new IllegalArgumentException("Different number of evaluator keys vs values");
121         }
122         mBackendType = in.readString();
123         mTestSetSize = in.readInt();
124         mCompilationCacheSizeBytes = in.readInt();
125         int validationsErrorsSize = in.readInt();
126         mValidationErrors = new String[validationsErrorsSize];
127         in.readStringArray(mValidationErrors);
128         mBenchmarkError = in.readString();
129     }
130 
131     @Override
describeContents()132     public int describeContents() {
133         return 0;
134     }
135 
136     @Override
writeToParcel(Parcel dest, int flags)137     public void writeToParcel(Parcel dest, int flags) {
138         dest.writeParcelable(mLatencyInference, flags);
139         dest.writeParcelable(mLatencyCompileWithoutCache, flags);
140         dest.writeParcelable(mLatencySaveToCache, flags);
141         dest.writeParcelable(mLatencyPrepareFromCache, flags);
142         dest.writeFloat(mSumOfMSEs);
143         dest.writeFloat(mMaxSingleError);
144         dest.writeString(mTestInfo);
145         dest.writeInt(mNumberOfEvaluatorResults);
146         dest.writeStringArray(mEvaluatorKeys);
147         dest.writeFloatArray(mEvaluatorResults);
148         dest.writeString(mBackendType);
149         dest.writeInt(mTestSetSize);
150         dest.writeInt(mCompilationCacheSizeBytes);
151         dest.writeInt(mValidationErrors.length);
152         dest.writeStringArray(mValidationErrors);
153         dest.writeString(mBenchmarkError);
154     }
155 
156     @SuppressWarnings("unused")
157     public static final Parcelable.Creator<BenchmarkResult> CREATOR =
158             new Parcelable.Creator<BenchmarkResult>() {
159                 @Override
160                 public BenchmarkResult createFromParcel(Parcel in) {
161                     return new BenchmarkResult(in);
162                 }
163 
164                 @Override
165                 public BenchmarkResult[] newArray(int size) {
166                     return new BenchmarkResult[size];
167                 }
168             };
169 
getError()170     public float getError() {
171         return mSumOfMSEs;
172     }
173 
getMeanTimeSec()174     public float getMeanTimeSec() {
175         return mLatencyInference.getMeanTimeSec();
176     }
177 
getEvaluatorResults()178     public List<Pair<String, Float>> getEvaluatorResults() {
179         List<Pair<String, Float>> results = new ArrayList<>();
180         for (int i = 0; i < mEvaluatorKeys.length; ++i) {
181             results.add(new Pair<>(mEvaluatorKeys[i], mEvaluatorResults[i]));
182         }
183         return results;
184     }
185 
186     @Override
toString()187     public String toString() {
188         if (!TextUtils.isEmpty(mBenchmarkError)) {
189             return mBenchmarkError;
190         }
191 
192         StringBuilder result = new StringBuilder("BenchmarkResult{" +
193                 "mTestInfo='" + mTestInfo + '\'' +
194                 ", mLatencyInference=" + mLatencyInference.toString() +
195                 ", mSumOfMSEs=" + mSumOfMSEs +
196                 ", mMaxSingleErrors=" + mMaxSingleError +
197                 ", mBackendType=" + mBackendType +
198                 ", mTestSetSize=" + mTestSetSize);
199         for (int i = 0; i < mEvaluatorKeys.length; i++) {
200             result.append(", ").append(mEvaluatorKeys[i]).append("=").append(mEvaluatorResults[i]);
201         }
202 
203         result.append(", mValidationErrors=[");
204         for (int i = 0; i < mValidationErrors.length; i++) {
205             result.append(mValidationErrors[i]);
206             if (i < mValidationErrors.length - 1) {
207                 result.append(",");
208             }
209         }
210         result.append("]");
211 
212         if (mLatencyCompileWithoutCache != null) {
213             result.append(", mLatencyCompileWithoutCache=")
214                     .append(mLatencyCompileWithoutCache.toString());
215         }
216         if (mLatencySaveToCache != null) {
217             result.append(", mLatencySaveToCache=").append(mLatencySaveToCache.toString());
218         }
219         if (mLatencyPrepareFromCache != null) {
220             result.append(", mLatencyPrepareFromCache=")
221                     .append(mLatencyPrepareFromCache.toString());
222         }
223         result.append(", mCompilationCacheSizeBytes=").append(mCompilationCacheSizeBytes);
224 
225         result.append('}');
226         return result.toString();
227     }
228 
hasBenchmarkError()229     public boolean hasBenchmarkError() {
230         return !TextUtils.isEmpty(mBenchmarkError);
231     }
232 
getBenchmarkError()233     public String getBenchmarkError() {
234         if (!hasBenchmarkError()) return null;
235 
236         return mBenchmarkError;
237     }
238 
setBenchmarkError(String benchmarkError)239     public void setBenchmarkError(String benchmarkError) {
240         mBenchmarkError = benchmarkError;
241     }
242 
getSummary(float baselineSec)243     public String getSummary(float baselineSec) {
244         if (hasBenchmarkError()) {
245             return getBenchmarkError();
246         }
247         return mLatencyInference.getSummary(baselineSec);
248     }
249 
toBundle(String testName)250     public Bundle toBundle(String testName) {
251         Bundle results = new Bundle();
252         if (!TextUtils.isEmpty(mBenchmarkError)) {
253             results.putString(testName + "_error", mBenchmarkError);
254             return results;
255         }
256 
257         mLatencyInference.putToBundle(results, testName + "_inference");
258         results.putFloat(testName + "_inference_mean_square_error",
259                 mSumOfMSEs / mLatencyInference.getIterations());
260         results.putFloat(testName + "_inference_max_single_error", mMaxSingleError);
261         for (int i = 0; i < mEvaluatorKeys.length; i++) {
262             results.putFloat(testName + "_inference_" + mEvaluatorKeys[i], mEvaluatorResults[i]);
263         }
264         if (mLatencyCompileWithoutCache != null) {
265             mLatencyCompileWithoutCache.putToBundle(results, testName + "_compile_without_cache");
266         }
267         if (mLatencySaveToCache != null) {
268             mLatencySaveToCache.putToBundle(results, testName + "_save_to_cache");
269         }
270         if (mLatencyPrepareFromCache != null) {
271             mLatencyPrepareFromCache.putToBundle(results, testName + "_prepare_from_cache");
272         }
273         if (mCompilationCacheSizeBytes > 0) {
274             results.putInt(testName + "_compilation_cache_size", mCompilationCacheSizeBytes);
275         }
276         return results;
277     }
278 
279     @SuppressWarnings("AndroidJdkLibsChecker")
toCsvLine()280     public String toCsvLine() {
281         if (!TextUtils.isEmpty(mBenchmarkError)) {
282             return "";
283         }
284 
285         StringBuilder sb = new StringBuilder();
286         sb.append(mTestInfo).append(',').append(mBackendType);
287 
288         mLatencyInference.appendToCsvLine(sb);
289 
290         sb.append(',').append(String.join(",",
291             String.valueOf(mMaxSingleError),
292             String.valueOf(mTestSetSize),
293             String.valueOf(mEvaluatorKeys.length),
294             String.valueOf(mValidationErrors.length)));
295 
296         for (int i = 0; i < mEvaluatorKeys.length; ++i) {
297             sb.append(',').append(mEvaluatorKeys[i]);
298         }
299 
300         for (int i = 0; i < mEvaluatorKeys.length; ++i) {
301             sb.append(',').append(mEvaluatorResults[i]);
302         }
303 
304         for (String validationError : mValidationErrors) {
305             sb.append(',').append(validationError.replace(',', ' '));
306         }
307 
308         sb.append(',').append(mLatencyCompileWithoutCache != null);
309         if (mLatencyCompileWithoutCache != null) {
310             mLatencyCompileWithoutCache.appendToCsvLine(sb);
311         }
312         sb.append(',').append(mLatencySaveToCache != null);
313         if (mLatencySaveToCache != null) {
314             mLatencySaveToCache.appendToCsvLine(sb);
315         }
316         sb.append(',').append(mLatencyPrepareFromCache != null);
317         if (mLatencyPrepareFromCache != null) {
318             mLatencyPrepareFromCache.appendToCsvLine(sb);
319         }
320         sb.append(',').append(mCompilationCacheSizeBytes);
321 
322         sb.append('\n');
323         return sb.toString();
324     }
325 
fromInferenceResults( String testInfo, String backendType, List<InferenceInOutSequence> inferenceInOuts, List<InferenceResult> inferenceResults, EvaluatorInterface evaluator)326     public static BenchmarkResult fromInferenceResults(
327             String testInfo,
328             String backendType,
329             List<InferenceInOutSequence> inferenceInOuts,
330             List<InferenceResult> inferenceResults,
331             EvaluatorInterface evaluator) {
332         float[] latencies = new float[inferenceResults.size()];
333         float sumOfMSEs = 0;
334         float maxSingleError = 0;
335         for (int i = 0; i < inferenceResults.size(); i++) {
336             InferenceResult iresult = inferenceResults.get(i);
337             latencies[i] = iresult.mComputeTimeSec;
338             if (iresult.mMeanSquaredErrors != null) {
339                 for (float mse : iresult.mMeanSquaredErrors) {
340                     sumOfMSEs += mse;
341                 }
342             }
343             if (iresult.mMaxSingleErrors != null) {
344                 for (float mse : iresult.mMaxSingleErrors) {
345                     if (mse > maxSingleError) {
346                         maxSingleError = mse;
347                     }
348                 }
349             }
350         }
351 
352         String[] evaluatorKeys = null;
353         float[] evaluatorResults = null;
354         String[] validationErrors = null;
355         if (evaluator != null) {
356             ArrayList<String> keys = new ArrayList<String>();
357             ArrayList<Float> results = new ArrayList<Float>();
358             ArrayList<String> validationErrorsList = new ArrayList<>();
359             evaluator.EvaluateAccuracy(inferenceInOuts, inferenceResults, keys, results,
360                     validationErrorsList);
361             evaluatorKeys = new String[keys.size()];
362             evaluatorKeys = keys.toArray(evaluatorKeys);
363             evaluatorResults = new float[results.size()];
364             for (int i = 0; i < evaluatorResults.length; i++) {
365                 evaluatorResults[i] = results.get(i).floatValue();
366             }
367             validationErrors = new String[validationErrorsList.size()];
368             validationErrorsList.toArray(validationErrors);
369         }
370 
371         // Calc test set size
372         int testSetSize = 0;
373         for (InferenceInOutSequence iios : inferenceInOuts) {
374             testSetSize += iios.size();
375         }
376 
377         return new BenchmarkResult(new LatencyResult(latencies), sumOfMSEs, maxSingleError,
378                 testInfo, evaluatorKeys, evaluatorResults, backendType, testSetSize,
379                 validationErrors);
380     }
381 
setCompilationBenchmarkResult(CompilationBenchmarkResult result)382     public void setCompilationBenchmarkResult(CompilationBenchmarkResult result) {
383         mLatencyCompileWithoutCache = new LatencyResult(result.mCompileWithoutCacheTimeSec);
384         if (result.mSaveToCacheTimeSec != null) {
385             mLatencySaveToCache = new LatencyResult(result.mSaveToCacheTimeSec);
386         }
387         if (result.mPrepareFromCacheTimeSec != null) {
388             mLatencyPrepareFromCache = new LatencyResult(result.mPrepareFromCacheTimeSec);
389         }
390         mCompilationCacheSizeBytes = result.mCacheSizeBytes;
391     }
392 }
393