1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.nn.benchmark.core; 18 19 import java.util.ArrayList; 20 import java.util.List; 21 import java.util.concurrent.atomic.AtomicReference; 22 23 /** Information about available benchmarking models */ 24 public class TestModels { 25 /** Entry for a single benchmarking model */ 26 public static class TestModelEntry { 27 /** Unique model name, used to find benchmark data */ 28 public final String mModelName; 29 30 /** Expected inference performance in seconds */ 31 public final float mBaselineSec; 32 33 /** Shape of input data */ 34 public final int[] mInputShape; 35 36 /** File pair asset input/output pairs */ 37 public final InferenceInOutSequence.FromAssets[] mInOutAssets; 38 39 /** Dataset inputs */ 40 public final InferenceInOutSequence.FromDataset[] mInOutDatasets; 41 42 /** Readable name for test output */ 43 public final String mTestName; 44 45 /** Name of model file, so that the same file can be reused */ 46 public final String mModelFile; 47 48 /** The evaluator to use for validating the results. */ 49 public final EvaluatorConfig mEvaluator; 50 51 /** Min SDK version that the model can run on. */ 52 public final int mMinSdkVersion; 53 54 /* Number of bytes per input data entry */ 55 public final int mInDataSize; 56 TestModelEntry(String modelName, float baselineSec, int[] inputShape, InferenceInOutSequence.FromAssets[] inOutAssets, InferenceInOutSequence.FromDataset[] inOutDatasets, String testName, String modelFile, EvaluatorConfig evaluator, int minSdkVersion, int inDataSize)57 public TestModelEntry(String modelName, float baselineSec, int[] inputShape, 58 InferenceInOutSequence.FromAssets[] inOutAssets, 59 InferenceInOutSequence.FromDataset[] inOutDatasets, String testName, 60 String modelFile, 61 EvaluatorConfig evaluator, int minSdkVersion, int inDataSize) { 62 mModelName = modelName; 63 mBaselineSec = baselineSec; 64 mInputShape = inputShape; 65 mInOutAssets = inOutAssets; 66 mInOutDatasets = inOutDatasets; 67 mTestName = testName; 68 mModelFile = modelFile; 69 mEvaluator = evaluator; 70 mMinSdkVersion = minSdkVersion; 71 mInDataSize = inDataSize; 72 } 73 createNNTestBase()74 public NNTestBase createNNTestBase() { 75 return new NNTestBase(mModelName, mModelFile, mInputShape, mInOutAssets, mInOutDatasets, 76 mEvaluator, mMinSdkVersion); 77 } 78 createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump)79 public NNTestBase createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump) { 80 return createNNTestBase(useNNApi, enableIntermediateTensorsDump, /*mmapModel=*/false); 81 } 82 createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump, boolean mmapModel)83 public NNTestBase createNNTestBase(boolean useNNApi, boolean enableIntermediateTensorsDump, 84 boolean mmapModel) { 85 NNTestBase test = createNNTestBase(); 86 test.useNNApi(useNNApi); 87 test.enableIntermediateTensorsDump(enableIntermediateTensorsDump); 88 test.setMmapModel(mmapModel); 89 return test; 90 } 91 toString()92 public String toString() { 93 return mModelName; 94 } 95 getTestName()96 public String getTestName() { 97 return mTestName; 98 } 99 100 withDisabledEvaluation()101 public TestModelEntry withDisabledEvaluation() { 102 return new TestModelEntry(mModelName, mBaselineSec, mInputShape, mInOutAssets, 103 mInOutDatasets, mTestName, mModelFile, 104 null, // Disable evaluation. 105 mMinSdkVersion, mInDataSize); 106 } 107 } 108 109 static private final List<TestModelEntry> sTestModelEntryList = new ArrayList<>(); 110 static private final AtomicReference<List<TestModelEntry>> frozenEntries = 111 new AtomicReference<>(null); 112 113 114 /** Add new benchmark model. */ registerModel(TestModelEntry model)115 static public void registerModel(TestModelEntry model) { 116 if (frozenEntries.get() != null) { 117 throw new IllegalStateException("Can't register new models after its list is frozen"); 118 } 119 sTestModelEntryList.add(model); 120 } 121 isListFrozen()122 public static boolean isListFrozen() { 123 return frozenEntries.get() != null; 124 } 125 126 /** 127 * Fetch list of test models. 128 * 129 * If this method was called at least once, then it's impossible to register new models. 130 */ modelsList()131 static public List<TestModelEntry> modelsList() { 132 frozenEntries.compareAndSet(null, sTestModelEntryList); 133 return frozenEntries.get(); 134 } 135 136 /** Fetch model by its name. */ getModelByName(String name)137 static public TestModelEntry getModelByName(String name) { 138 for (TestModelEntry testModelEntry : modelsList()) { 139 if (testModelEntry.mModelName.equals(name)) { 140 return testModelEntry; 141 } 142 } 143 throw new IllegalArgumentException("Unknown TestModelEntry named " + name); 144 } 145 146 } 147