1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import android.os.Bundle; 20 import android.os.Parcel; 21 import android.os.Parcelable; 22 23 public class LatencyResult implements Parcelable { 24 private final static int TIME_FREQ_ARRAY_SIZE = 32; 25 26 private float mTotalTimeSec; 27 private int mIterations; 28 private float mTimeStdDeviation; 29 30 /** Time offset for inference frequency counts */ 31 private float mTimeFreqStartSec; 32 33 /** Index time offset for inference frequency counts */ 34 private float mTimeFreqStepSec; 35 36 /** 37 * Array of inference frequency counts. 38 * Each entry contains inference count for time range: 39 * [mTimeFreqStartSec + i*mTimeFreqStepSec, mTimeFreqStartSec + (1+i*mTimeFreqStepSec) 40 */ 41 private float[] mTimeFreqSec = {}; 42 LatencyResult(float[] results)43 public LatencyResult(float[] results) { 44 mIterations = results.length; 45 mTotalTimeSec = 0.0f; 46 float maxComputeTimeSec = 0.0f; 47 float minComputeTimeSec = Float.MAX_VALUE; 48 for (float result : results) { 49 mTotalTimeSec += result; 50 maxComputeTimeSec = Math.max(maxComputeTimeSec, result); 51 minComputeTimeSec = Math.min(minComputeTimeSec, result); 52 } 53 54 // Calculate standard deviation. 55 float latencyMean = (mTotalTimeSec / mIterations); 56 float variance = 0.0f; 57 for (float result : results) { 58 float v = (result - latencyMean); 59 variance += v * v; 60 } 61 variance /= mIterations; 62 mTimeStdDeviation = (float) Math.sqrt(variance); 63 64 // Calculate inference frequency/histogram across TIME_FREQ_ARRAY_SIZE buckets. 65 mTimeFreqStartSec = minComputeTimeSec; 66 mTimeFreqStepSec = (maxComputeTimeSec - minComputeTimeSec) / (TIME_FREQ_ARRAY_SIZE - 1); 67 mTimeFreqSec = new float[TIME_FREQ_ARRAY_SIZE]; 68 for (float result : results) { 69 int bucketIndex = (int) ((result - minComputeTimeSec) / mTimeFreqStepSec); 70 mTimeFreqSec[bucketIndex] += 1; 71 } 72 } 73 LatencyResult(Parcel in)74 public LatencyResult(Parcel in) { 75 mTotalTimeSec = in.readFloat(); 76 mIterations = in.readInt(); 77 mTimeStdDeviation = in.readFloat(); 78 mTimeFreqStartSec = in.readFloat(); 79 mTimeFreqStepSec = in.readFloat(); 80 int timeFreqSecLength = in.readInt(); 81 mTimeFreqSec = new float[timeFreqSecLength]; 82 in.readFloatArray(mTimeFreqSec); 83 } 84 85 @Override describeContents()86 public int describeContents() { 87 return 0; 88 } 89 90 @Override writeToParcel(Parcel dest, int flags)91 public void writeToParcel(Parcel dest, int flags) { 92 dest.writeFloat(mTotalTimeSec); 93 dest.writeInt(mIterations); 94 dest.writeFloat(mTimeStdDeviation); 95 dest.writeFloat(mTimeFreqStartSec); 96 dest.writeFloat(mTimeFreqStepSec); 97 dest.writeInt(mTimeFreqSec.length); 98 dest.writeFloatArray(mTimeFreqSec); 99 } 100 putToBundle(Bundle results, String prefix)101 public void putToBundle(Bundle results, String prefix) { 102 // Reported in ms 103 results.putFloat(prefix + "_avg", getMeanTimeSec() * 1000.0f); 104 results.putFloat(prefix + "_std_dev", mTimeStdDeviation * 1000.0f); 105 results.putFloat(prefix + "_total_time", mTotalTimeSec * 1000.0f); 106 results.putInt(prefix + "_iterations", mIterations); 107 } 108 109 @Override toString()110 public String toString() { 111 return "LatencyResult{" 112 + "getMeanTimeSec()=" + getMeanTimeSec() 113 + ", mTotalTimeSec=" + mTotalTimeSec 114 + ", mIterations=" + mIterations 115 + ", mTimeStdDeviation=" + mTimeStdDeviation 116 + ", mTimeFreqStartSec=" + mTimeFreqStartSec 117 + ", mTimeFreqStepSec=" + mTimeFreqStepSec + "}"; 118 } 119 getIterations()120 public int getIterations() { return mIterations; } 121 getMeanTimeSec()122 public float getMeanTimeSec() { return mTotalTimeSec / mIterations; } 123 rebase(float v, float baselineSec)124 private float rebase(float v, float baselineSec) { 125 if (v > 0.001) { 126 v = baselineSec / v; 127 } 128 return v; 129 } 130 getSummary(float baselineSec)131 public String getSummary(float baselineSec) { 132 java.text.DecimalFormat df = new java.text.DecimalFormat("######.##"); 133 return df.format(rebase(getMeanTimeSec(), baselineSec)) + "X, n=" + mIterations 134 + ", μ=" + df.format(getMeanTimeSec() * 1000.0) 135 + "ms, σ=" + df.format(mTimeStdDeviation * 1000.0) + "ms"; 136 } 137 appendToCsvLine(StringBuilder sb)138 public void appendToCsvLine(StringBuilder sb) { 139 sb.append(',').append(String.join(",", 140 String.valueOf(mIterations), 141 String.valueOf(mTotalTimeSec), 142 String.valueOf(mTimeFreqStartSec), 143 String.valueOf(mTimeFreqStepSec), 144 String.valueOf(mTimeFreqSec.length))); 145 146 for (float value : mTimeFreqSec) { 147 sb.append(',').append(value); 148 } 149 } 150 } 151