1#!/usr/bin/python3 2 3# Copyright 2018, The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17"""Example generator 18 19Compiles spec files and generates the corresponding C++ TestModel definitions. 20Invoked by ml/nn/runtime/test/specs/generate_all_tests.sh; 21See that script for details on how this script is used. 22 23""" 24 25from __future__ import absolute_import 26from __future__ import division 27from __future__ import print_function 28import os 29import sys 30import traceback 31 32import test_generator as tg 33 34# See ToCpp() 35COMMENT_KEY = "__COMMENT__" 36 37# Take a model from command line 38def ParseCmdLine(): 39 parser = tg.ArgumentParser() 40 parser.add_argument("-e", "--example", help="the output example file or directory") 41 args = tg.ParseArgs(parser) 42 tg.FileNames.InitializeFileLists(args.spec, args.example) 43 44# Write headers for generated files, which are boilerplate codes only related to filenames 45def InitializeFiles(example_fd): 46 specFileBase = os.path.basename(tg.FileNames.specFile) 47 fileHeader = """\ 48// Generated from {spec_file} 49// DO NOT EDIT 50// clang-format off 51#include "TestHarness.h" 52using namespace test_helper; 53""" 54 if example_fd is not None: 55 print(fileHeader.format(spec_file=specFileBase), file=example_fd) 56 57def IndentedStr(s, indent): 58 return ("\n" + " " * indent).join(s.split('\n')) 59 60def ToCpp(var, indent=0): 61 """Get the C++-style representation of a Python object. 62 63 For Python dictionary, it will be mapped to C++ struct aggregate initialization: 64 { 65 .key0 = value0, 66 .key1 = value1, 67 ... 68 } 69 70 For Python list, it will be mapped to C++ list initalization: 71 {value0, value1, ...} 72 73 In both cases, value0, value1, ... are stringified by invoking this method recursively. 74 """ 75 if isinstance(var, dict): 76 if not var: 77 return "{}" 78 comment = var.get(COMMENT_KEY) 79 comment = "" if comment is None else " // %s" % comment 80 str_pair = lambda k, v: " .%s = %s" % (k, ToCpp(v, indent + 4)) 81 agg_init = "{%s\n%s\n}" % (comment, 82 ",\n".join(str_pair(k, var[k]) 83 for k in sorted(var.keys()) 84 if k != COMMENT_KEY)) 85 return IndentedStr(agg_init, indent) 86 elif isinstance(var, (list, tuple)): 87 return "{%s}" % (", ".join(ToCpp(i, indent) for i in var)) 88 elif type(var) is bool: 89 return "true" if var else "false" 90 elif type(var) is float: 91 return tg.PrettyPrintAsFloat(var) 92 else: 93 return str(var) 94 95def GetSymmPerChannelQuantParams(extraParams): 96 """Get the dictionary that corresponds to test_helper::TestSymmPerChannelQuantParams.""" 97 if extraParams is None or extraParams.hide: 98 return {} 99 else: 100 return {"scales": extraParams.scales, "channelDim": extraParams.channelDim} 101 102def GetOperandStruct(operand): 103 """Get the dictionary that corresponds to test_helper::TestOperand.""" 104 return { 105 COMMENT_KEY: operand.name, 106 "type": "TestOperandType::" + operand.type.type, 107 "dimensions": operand.type.dimensions, 108 "scale": operand.type.scale, 109 "zeroPoint": operand.type.zeroPoint, 110 "numberOfConsumers": len(operand.outs), 111 "lifetime": "TestOperandLifeTime::" + operand.lifetime, 112 "channelQuant": GetSymmPerChannelQuantParams(operand.type.extraParams), 113 "isIgnored": isinstance(operand, tg.IgnoredOutput), 114 "data": "TestBuffer::createFromVector<{cpp_type}>({data})".format( 115 cpp_type=operand.type.GetCppTypeString(), 116 data=operand.GetListInitialization(), 117 ) 118 } 119 120def GetOperationStruct(operation): 121 """Get the dictionary that corresponds to test_helper::TestOperation.""" 122 return { 123 "type": "TestOperationType::" + operation.optype, 124 "inputs": [op.model_index for op in operation.ins], 125 "outputs": [op.model_index for op in operation.outs], 126 } 127 128def GetSubgraphStruct(subgraph): 129 """Get the dictionary that corresponds to test_helper::TestSubgraph.""" 130 return { 131 COMMENT_KEY: subgraph.name, 132 "operands": [GetOperandStruct(op) for op in subgraph.operands], 133 "operations": [GetOperationStruct(op) for op in subgraph.operations], 134 "inputIndexes": [op.model_index for op in subgraph.GetInputs()], 135 "outputIndexes": [op.model_index for op in subgraph.GetOutputs()], 136 } 137 138def GetModelStruct(example): 139 """Get the dictionary that corresponds to test_helper::TestModel.""" 140 return { 141 "main": GetSubgraphStruct(example.model), 142 "referenced": [GetSubgraphStruct(model) for model in example.model.GetReferencedModels()], 143 "isRelaxed": example.model.isRelaxed, 144 "expectedMultinomialDistributionTolerance": 145 example.expectedMultinomialDistributionTolerance, 146 "expectFailure": example.expectFailure, 147 "minSupportedVersion": "TestHalVersion::%s" % ( 148 example.model.version if example.model.version is not None else "UNKNOWN"), 149 } 150 151def DumpExample(example, example_fd): 152 assert example.model.compiled 153 template = """\ 154namespace generated_tests::{spec_name} {{ 155 156const TestModel& get_{example_name}() {{ 157 static TestModel model = {aggregate_init}; 158 return model; 159}} 160 161const auto dummy_{example_name} = TestModelManager::get().add("{test_name}", get_{example_name}()); 162 163}} // namespace generated_tests::{spec_name} 164""" 165 print(template.format( 166 spec_name=tg.FileNames.specName, 167 test_name=str(example.testName), 168 example_name=str(example.examplesName), 169 aggregate_init=ToCpp(GetModelStruct(example), indent=4), 170 ), file=example_fd) 171 172 173if __name__ == '__main__': 174 ParseCmdLine() 175 tg.Run(InitializeFiles=InitializeFiles, DumpExample=DumpExample) 176