1#!/usr/bin/python3
2
3# Copyright 2018, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Example generator
18
19Compiles spec files and generates the corresponding C++ TestModel definitions.
Invoked by ml/nn/runtime/test/specs/generate_all_tests.sh;
see that script for details on how this script is used.
22
23"""
24
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28import os
29import sys
30import traceback
31
32import test_generator as tg
33
# Reserved dictionary key understood by ToCpp(): when a dict contains it, its
# value is printed as a trailing "// ..." C++ comment after the opening brace
# instead of being emitted as a ".field = value" member.
COMMENT_KEY = "__COMMENT__"
36
37# Take a model from command line
38def ParseCmdLine():
39    parser = tg.ArgumentParser()
40    parser.add_argument("-e", "--example", help="the output example file or directory")
41    args = tg.ParseArgs(parser)
42    tg.FileNames.InitializeFileLists(args.spec, args.example)
43
def InitializeFiles(example_fd):
    """Write the boilerplate header at the top of a generated C++ file.

    The header depends only on the spec file name. It records which spec the
    file was generated from, marks the file as generated, turns off
    clang-format, includes TestHarness.h, and opens namespace test_helper.

    Args:
        example_fd: writable text stream for the generated file, or None, in
            which case nothing is written.
    """
    header_template = """\
// Generated from {spec_file}
// DO NOT EDIT
// clang-format off
#include "TestHarness.h"
using namespace test_helper;
"""
    # The basename is computed before the None check, as in the original, so
    # an invalid tg.FileNames.specFile is reported either way.
    spec_base_name = os.path.basename(tg.FileNames.specFile)
    if example_fd is None:
        return
    print(header_template.format(spec_file=spec_base_name), file=example_fd)
56
def IndentedStr(s, indent):
    """Indent every line of `s` except the first by `indent` spaces.

    The first line is left untouched so the result can continue text that is
    already on the current output line (e.g. after "model = ").
    """
    return s.replace("\n", "\n" + " " * indent)
59
def ToCpp(var, indent=0):
    """Return the C++ source spelling of a Python value.

    Mapping rules:
      * dict -> C++ aggregate initialization with designated fields, one per
        line, in sorted key order:
            {
                .key0 = value0,
                .key1 = value1,
                ...
            }
        An empty dict becomes "{}". If the dict contains COMMENT_KEY, its
        value is emitted as a " // ..." comment after the opening brace and
        is not emitted as a field. All lines after the first are indented by
        `indent` extra spaces (via IndentedStr), so the result can follow
        text already on the caller's line.
      * list / tuple -> C++ list initialization "{value0, value1, ...}".
      * bool -> "true" / "false".
      * float -> tg.PrettyPrintAsFloat(var).
      * anything else -> str(var).

    Nested values are converted by calling this function recursively; dict
    field values are rendered at `indent + 4`.
    """
    if isinstance(var, dict):
        if len(var) == 0:
            return "{}"
        note = var.get(COMMENT_KEY)
        opening = "{" if note is None else "{ // %s" % note
        fields = []
        for key in sorted(var.keys()):
            if key == COMMENT_KEY:
                continue
            fields.append("    .%s = %s" % (key, ToCpp(var[key], indent + 4)))
        body = "%s\n%s\n}" % (opening, ",\n".join(fields))
        return IndentedStr(body, indent)
    if isinstance(var, (list, tuple)):
        return "{" + ", ".join(ToCpp(item, indent) for item in var) + "}"
    # Exact type checks on purpose: bool is a subclass of int, and only real
    # Python floats should go through the float pretty-printer.
    if type(var) is bool:
        return "true" if var else "false"
    if type(var) is float:
        return tg.PrettyPrintAsFloat(var)
    return str(var)
94
def GetSymmPerChannelQuantParams(extraParams):
    """Build the dict for test_helper::TestSymmPerChannelQuantParams.

    Returns the "scales" and "channelDim" fields of `extraParams`. Returns an
    empty dict when `extraParams` is None or its `hide` attribute is truthy.
    """
    if extraParams is not None and not extraParams.hide:
        return {"scales": extraParams.scales, "channelDim": extraParams.channelDim}
    return {}
101
def GetOperandStruct(operand):
    """Build the dict for one test_helper::TestOperand.

    The operand's name is attached as a C++ comment (COMMENT_KEY). "data" is
    a C++ expression that builds a TestBuffer from the operand's values,
    using the element type reported by operand.type.GetCppTypeString().
    """
    struct = {COMMENT_KEY: operand.name}
    operandType = operand.type
    struct["type"] = "TestOperandType::" + operandType.type
    struct["dimensions"] = operandType.dimensions
    struct["scale"] = operandType.scale
    struct["zeroPoint"] = operandType.zeroPoint
    struct["numberOfConsumers"] = len(operand.outs)
    struct["lifetime"] = "TestOperandLifeTime::" + operand.lifetime
    struct["channelQuant"] = GetSymmPerChannelQuantParams(operandType.extraParams)
    struct["isIgnored"] = isinstance(operand, tg.IgnoredOutput)
    struct["data"] = "TestBuffer::createFromVector<{cpp_type}>({data})".format(
        cpp_type=operandType.GetCppTypeString(),
        data=operand.GetListInitialization(),
    )
    return struct
119
def GetOperationStruct(operation):
    """Build the dict for one test_helper::TestOperation.

    Input and output operands are referenced by their model_index.
    """
    def modelIndexes(operands):
        return [operand.model_index for operand in operands]

    return {
        "type": "TestOperationType::" + operation.optype,
        "inputs": modelIndexes(operation.ins),
        "outputs": modelIndexes(operation.outs),
    }
127
def GetSubgraphStruct(subgraph):
    """Build the dict for one test_helper::TestSubgraph.

    Operands and operations are converted with GetOperandStruct and
    GetOperationStruct. Graph inputs and outputs are listed by model_index.
    The subgraph's name is attached as a C++ comment (COMMENT_KEY).
    """
    def modelIndexes(operands):
        return [operand.model_index for operand in operands]

    return {
        COMMENT_KEY: subgraph.name,
        "operands": list(map(GetOperandStruct, subgraph.operands)),
        "operations": list(map(GetOperationStruct, subgraph.operations)),
        "inputIndexes": modelIndexes(subgraph.GetInputs()),
        "outputIndexes": modelIndexes(subgraph.GetOutputs()),
    }
137
def GetModelStruct(example):
    """Build the dict for test_helper::TestModel from one example.

    "main" is the example's own model. "referenced" holds every model
    returned by example.model.GetReferencedModels(). A model whose version is
    None is tagged TestHalVersion::UNKNOWN.
    """
    model = example.model
    struct = {
        "main": GetSubgraphStruct(model),
        "referenced": [GetSubgraphStruct(ref) for ref in model.GetReferencedModels()],
        "isRelaxed": model.isRelaxed,
        "expectedMultinomialDistributionTolerance":
                example.expectedMultinomialDistributionTolerance,
        "expectFailure": example.expectFailure,
    }
    version = model.version
    struct["minSupportedVersion"] = "TestHalVersion::%s" % (
            "UNKNOWN" if version is None else version)
    return struct
150
def DumpExample(example, example_fd):
    """Print the C++ TestModel definition for one compiled example.

    The emitted code lives in namespace generated_tests::<specName>. It
    defines get_<examplesName>(), which returns a static TestModel built from
    GetModelStruct(example), and registers that model with TestModelManager
    under example.testName.

    Args:
        example: the example to emit; example.model must already be compiled.
        example_fd: writable text stream that receives the generated code.
    """
    # Internal invariant: the caller is expected to have compiled the model.
    assert example.model.compiled
    cppTemplate = """\
namespace generated_tests::{spec_name} {{

const TestModel& get_{example_name}() {{
    static TestModel model = {aggregate_init};
    return model;
}}

const auto dummy_{example_name} = TestModelManager::get().add("{test_name}", get_{example_name}());

}}  // namespace generated_tests::{spec_name}
"""
    specName = tg.FileNames.specName
    testName = str(example.testName)
    exampleName = str(example.examplesName)
    # indent=4 lines the aggregate's continuation lines up with the
    # 4-space-indented "static TestModel model =" line in the template.
    aggregateInit = ToCpp(GetModelStruct(example), indent=4)
    generated = cppTemplate.format(
        spec_name=specName,
        test_name=testName,
        example_name=exampleName,
        aggregate_init=aggregateInit,
    )
    print(generated, file=example_fd)
171
172
if __name__ == "__main__":
    # Parse arguments first: ParseCmdLine() fills tg.FileNames, which both
    # InitializeFiles and DumpExample read while tg.Run() drives generation.
    ParseCmdLine()
    tg.Run(DumpExample=DumpExample, InitializeFiles=InitializeFiles)
176