1#!/usr/bin/python3
2
3# Copyright 2019, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Spec Visualizer
18
Visualize a Python spec file for the test generator.
20Invoked by ml/nn/runtime/test/specs/visualize_spec.sh;
21See that script for details on how this script is used.
22"""
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27import argparse
28import json
29import os
30import sys
31from string import Template
32
33# Stuff from test generator
34import test_generator as tg
35from test_generator import ActivationConverter
36from test_generator import BoolScalar
37from test_generator import Configuration
38from test_generator import DataTypeConverter
39from test_generator import DataLayoutConverter
40from test_generator import Example
41from test_generator import Float16Scalar
42from test_generator import Float32Scalar
43from test_generator import Float32Vector
44from test_generator import GetJointStr
45from test_generator import IgnoredOutput
46from test_generator import Input
47from test_generator import Int32Scalar
48from test_generator import Int32Vector
49from test_generator import Internal
50from test_generator import Model
51from test_generator import Operand
52from test_generator import Output
53from test_generator import Parameter
54from test_generator import RelaxedModeConverter
55from test_generator import SymmPerChannelQuantParams
56
57
# Absolute path of the HTML template, which lives next to this script.
TEMPLATE_FILE = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    "spec_viz_template.html",
)
# Maps str(example.testName) -> graph and detail data; populated by
# ProcessExample and serialized into the page by DumpHtml.
global_graphs = {}
60
61
def FormatArray(data, is_scalar=False):
    """Format a sequence of values as display text.

    Args:
        data: the values to format.
        is_scalar: when True, ``data`` must contain exactly one element, which
            is rendered on its own without brackets.

    Returns:
        ``str(data[0])`` for a scalar, otherwise ``"[v0, v1, ...]"``.
    """
    if not is_scalar:
        return "[" + ", ".join(map(str, data)) + "]"
    # Internal invariant: scalar operands carry a single value.
    assert len(data) == 1
    return str(data[0])
68
69
def FormatDict(data):
    """Render a dict as HTML lines of the form ``<b>Key:</b> value``.

    Keys go through ``str.capitalize`` (first letter upper-case, the rest
    lower-case). Lines follow the dict's iteration order and are joined with
    ``<br/>``. Values are inserted via ``%s`` without HTML escaping.
    """
    lines = []
    for key, value in data.items():
        lines.append("<b>%s:</b> %s" % (key.capitalize(), value))
    return "<br/>".join(lines)
72
73
def GetOperandInfo(op):
    """Collect the type information of an operand for display.

    Returns:
        A dict whose entries, in insertion order, are:
        - "lifetime" and "type", always present;
        - "dimensions", formatted with FormatArray, for non-scalar operands;
        - "scale" and "zero point", when the type's scale is nonzero;
        - for TENSOR_QUANT8_SYMM_PER_CHANNEL, "scale" set to the formatted
          per-channel scales (overwriting any scalar scale in place) and
          "channel dim".
    """
    operand_type = op.type
    info = {"lifetime": op.lifetime, "type": operand_type.type}

    if not operand_type.IsScalar():
        info["dimensions"] = FormatArray(operand_type.dimensions)

    if operand_type.scale != 0:
        info["scale"] = operand_type.scale
        info["zero point"] = operand_type.zeroPoint

    if operand_type.type == "TENSOR_QUANT8_SYMM_PER_CHANNEL":
        channel_params = operand_type.extraParams
        info["scale"] = FormatArray(channel_params.scales)
        info["channel dim"] = channel_params.channelDim

    return info
88
89
def FormatOperand(op):
    """Render an operand as an HTML link wrapped in a hover tooltip.

    The tooltip lists every entry of GetOperandInfo(op). For Parameter
    operands holding at most 10 values, the values are listed too; larger
    parameters are omitted to keep the tooltip short. The link targets the
    operand's row in the details table ("#details-operands-<model_index>").
    """
    tooltip_info = GetOperandInfo(op)
    is_small_parameter = isinstance(op, Parameter) and len(op.value) <= 10
    if is_small_parameter:
        tooltip_info["data"] = FormatArray(op.value, op.type.IsScalar())

    html = ("<span class='tooltip'><span class='tooltipcontent'>{tooltip_content}</span>"
            "<a href=\"{inpage_link}\">{op_name}</a></span>")
    return html.format(
        op_name=str(op),
        tooltip_content=FormatDict(tooltip_info),
        inpage_link="#details-operands-%d" % (op.model_index),
    )
103
104
def GetSubgraph(example):
    """Build the node/edge description of an example's model for d3.

    Operands and operations are both graph nodes. Every node is assigned a
    layer (its y position) and a slot within that layer (its x position).

    Args:
        example: the Example whose ``model`` is drawn. ``model.operations``
            is assumed to be topologically sorted already.

    Returns:
        ``{"nodes": [...], "edges": [...]}``. Operation nodes (group 2, id
        ``"operation<i>"``) come first, then operand nodes (group 1, id
        ``str(operand)``). Each edge links an operand id and an operation id.
    """
    model = example.model

    # Visit order: for every operation, its inputs, the operation itself and
    # then its outputs; a node is recorded only the first time it is met.
    # Since the operations are already sorted, this order is topological.
    ordered = []
    visited = set()
    for operation in model.operations:
        for group in (operation.ins, (operation,), operation.outs):
            for node in group:
                if node not in visited:
                    visited.add(node)
                    ordered.append(node)

    # Forward pass: each node goes one layer below the deepest of its
    # producers (layer 0 when it has none).
    layer_of = {}
    for node in ordered:
        producer_layers = [layer_of[src] for src in node.ins]
        layer_of[node] = max(producer_layers) + 1 if producer_layers else 0

    # Backward pass: a node with consumers moves down to the layer just above
    # its shallowest consumer. For example, an input with no producer that is
    # consumed deep in the graph is drawn right above its consumer, not in
    # layer 0. Nodes without consumers keep their forward-pass layer.
    for node in reversed(ordered):
        consumer_layers = [layer_of[dst] for dst in node.outs]
        if consumer_layers:
            layer_of[node] = min(consumer_layers) - 1
    layer_count = max(layer_of.values()) + 1

    # Space nodes evenly: 200px between slots in a layer, 100px between layers.
    next_slot = [0] * layer_count
    position = {}
    for node in ordered:
        row = layer_of[node]
        slot = next_slot[row]
        next_slot[row] = slot + 1
        position[node] = ((slot + 0.5) * 200, (row + 0.5) * 100)

    def OperationId(k):
        return "operation%d" % k

    edges = []
    nodes = []
    for k, operation in enumerate(model.operations):
        op_id = OperationId(k)
        edges.extend({"source": str(t), "target": op_id} for t in operation.ins)
        edges.extend({"target": str(t), "source": op_id} for t in operation.outs)
        x, y = position[operation]
        nodes.append({"index": k, "id": op_id, "name": operation.optype,
                      "group": 2, "x": x, "y": y})
    for k, operand in enumerate(model.operands):
        x, y = position[operand]
        nodes.append({"index": k, "id": str(operand), "name": str(operand),
                      "group": 1, "x": x, "y": y})

    return {"nodes": nodes, "edges": edges}
178
179
180# The following Get**Info methods will each return a list of dictionaries,
181# whose content will appear in the tables and sidebar views.
def GetConfigurationsInfo(example):
    """Describe the example's run configuration as a one-row table.

    Returns:
        A single-element list holding a dict with the string forms of the
        model's relaxed flag, the shared-memory setting reported by
        tg.Configuration.useSHM(), and the example's expectFailure flag.
    """
    settings = {}
    settings["relaxed"] = str(example.model.isRelaxed)
    settings["use shared memory"] = str(tg.Configuration.useSHM())
    settings["expect failure"] = str(example.expectFailure)
    return [settings]
188
189
def GetOperandsInfo(example):
    """Build one details-table row per operand of the example's model.

    Each row has "index", "name" and "group" ("operand"), followed by the
    entries of GetOperandInfo. Parameter, Input and Output operands also get
    a "data" entry with all their values (no length cap, unlike tooltips).
    """
    rows = []
    for position, operand in enumerate(example.model.operands):
        row = {"index": position, "name": str(operand), "group": "operand"}
        row.update(GetOperandInfo(operand))
        if isinstance(operand, (Parameter, Input, Output)):
            row["data"] = FormatArray(operand.value, operand.type.IsScalar())
        rows.append(row)
    return rows
202
203
def GetOperationsInfo(example):
    """Build one details-table row per operation of the example's model.

    Both "name" and "opcode" hold the operation's optype. "inputs" and
    "outputs" are comma-separated FormatOperand HTML snippets.
    """
    rows = []
    for position, operation in enumerate(example.model.operations):
        inputs_html = ", ".join(map(FormatOperand, operation.ins))
        outputs_html = ", ".join(map(FormatOperand, operation.outs))
        rows.append({
            "index": position,
            "name": operation.optype,
            "group": "operation",
            "opcode": operation.optype,
            "inputs": inputs_html,
            "outputs": outputs_html,
        })
    return rows
213
214
# TODO: Remove the unused fd from the parameter.
def ProcessExample(example, fd):
    """Store the graph and detail tables of one example in global_graphs.

    The entry is keyed by ``str(example.testName)``. This function is passed
    as the DumpExample callback to Example.DumpAllExamples in the __main__
    section.

    Args:
        example: the test Example (one variation) to visualize.
        fd: unused; presumably supplied by the DumpAllExamples callback
            protocol (example_fd) -- see the TODO above.
    """
    print("    Processing variation %s" % example.testName)
    # global_graphs is only mutated, never rebound, so no `global` is needed.
    subgraph = GetSubgraph(example)
    details = {
        "configurations": GetConfigurationsInfo(example),
        "operands": GetOperandsInfo(example),
        "operations": GetOperationsInfo(example),
    }
    global_graphs[str(example.testName)] = {
        "subgraph": subgraph,
        "details": details,
    }
229
230
def DumpHtml(spec_file, out_file, template_file=None, graphs=None):
    """Write the final HTML file by filling in the template.

    The template's ``$spec_name`` placeholder receives the basename of
    ``spec_file``. Its ``$graph_dump`` placeholder receives the JSON
    serialization of ``graphs``.

    Args:
        spec_file: path of the visualized spec; only its basename is shown.
        out_file: path of the HTML file to (over)write.
        template_file: template path; defaults to TEMPLATE_FILE.
        graphs: JSON-serializable graph data; defaults to global_graphs.

    Raises:
        KeyError, ValueError: from string.Template.substitute when the
            template has an unknown or malformed placeholder. In that case
            out_file is left untouched.
    """
    if template_file is None:
        template_file = TEMPLATE_FILE
    if graphs is None:
        graphs = global_graphs

    # Explicit UTF-8 so the result does not depend on the platform locale.
    with open(template_file, "r", encoding="utf-8") as template_fd:
        html_template = template_fd.read()

    # Render fully before opening out_file. Opening for write truncates the
    # file, so a substitution error must not happen after that point.
    html = Template(html_template).substitute(
        spec_name=os.path.basename(spec_file),
        graph_dump=json.dumps(graphs),
    )
    with open(out_file, "w", encoding="utf-8") as out_fd:
        out_fd.write(html)
242
243
def ParseCmdLine():
    """Parse the command line and register the spec with test_generator.

    Returns:
        A ``(spec_path, out_path)`` tuple of absolute paths. These are the spec
        to visualize and the HTML file to write (default "out.html").
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("spec", help="the spec file")
    arg_parser.add_argument("-o", "--out", default="out.html", help="the output html path")
    options = arg_parser.parse_args()
    # Prime test_generator's file bookkeeping with this spec. "-" is presumably
    # a placeholder output location; confirm against tg.FileNames.
    tg.FileNames.InitializeFileLists(options.spec, "-")
    tg.FileNames.NextFile()
    spec_path = os.path.abspath(options.spec)
    out_path = os.path.abspath(options.out)
    return spec_path, out_path
252
253
if __name__ == '__main__':
    spec_file, out_file = ParseCmdLine()
    print("Visualizing from spec: %s" % spec_file)
    # Read the spec through a context manager so the handle is closed
    # promptly. UTF-8 matches the default encoding of Python source files.
    with open(spec_file, "r", encoding="utf-8") as spec_fd:
        spec_source = spec_fd.read()
    # The spec is a Python program. It runs at module scope so it can use the
    # test_generator names imported at the top of this file. NOTE: this
    # executes arbitrary code from spec_file -- only visualize trusted specs.
    # Compiling with the spec path makes tracebacks point into the spec.
    exec(compile(spec_source, spec_file, "exec"))
    Example.DumpAllExamples(DumpExample=ProcessExample, example_fd=0)
    DumpHtml(spec_file, out_file)
    print("Output HTML file: %s" % out_file)
261
262