1#!/usr/bin/python3 2 3# Copyright 2019, The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17"""Spec Visualizer 18 19Visualize python spec file for test generator. 20Invoked by ml/nn/runtime/test/specs/visualize_spec.sh; 21See that script for details on how this script is used. 22""" 23 24from __future__ import absolute_import 25from __future__ import division 26from __future__ import print_function 27import argparse 28import json 29import os 30import sys 31from string import Template 32 33# Stuff from test generator 34import test_generator as tg 35from test_generator import ActivationConverter 36from test_generator import BoolScalar 37from test_generator import Configuration 38from test_generator import DataTypeConverter 39from test_generator import DataLayoutConverter 40from test_generator import Example 41from test_generator import Float16Scalar 42from test_generator import Float32Scalar 43from test_generator import Float32Vector 44from test_generator import GetJointStr 45from test_generator import IgnoredOutput 46from test_generator import Input 47from test_generator import Int32Scalar 48from test_generator import Int32Vector 49from test_generator import Internal 50from test_generator import Model 51from test_generator import Operand 52from test_generator import Output 53from test_generator import Parameter 54from test_generator import RelaxedModeConverter 55from test_generator import SymmPerChannelQuantParams 56 57 
TEMPLATE_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "spec_viz_template.html")

# Visualization data keyed by example test name; populated by ProcessExample and
# serialized into the output page by DumpHtml.
global_graphs = {}


def FormatArray(data, is_scalar=False):
    """Formats a sequence of values as display text.

    Args:
        data: Sequence of values; each element is rendered with str().
        is_scalar: If True, data must hold exactly one element, which is returned
            on its own instead of being wrapped in brackets.

    Returns:
        "v" for a scalar, otherwise "[v0, v1, ...]".
    """
    if is_scalar:
        assert len(data) == 1
        return str(data[0])
    return "[%s]" % ", ".join(str(i) for i in data)


def FormatDict(data):
    """Renders a dict as HTML, one "<b>Key:</b> value" entry per line.

    Keys go through str.capitalize() for display and entries are joined with
    <br/>. Values are inserted verbatim (no HTML escaping).
    """
    return "<br/>".join("<b>%s:</b> %s" % (k.capitalize(), v) for k, v in data.items())


def GetOperandInfo(op):
    """Collects the type information of an operand for tooltips and tables.

    Returns:
        A dict with "lifetime" and "type", plus "dimensions" for non-scalar
        operands, "scale"/"zero point" when the type has a non-zero scale, and
        "scale"/"channel dim" for TENSOR_QUANT8_SYMM_PER_CHANNEL operands (the
        per-channel scales overwrite any scalar "scale" entry).
    """
    op_info = {"lifetime": op.lifetime, "type": op.type.type}

    if not op.type.IsScalar():
        op_info["dimensions"] = FormatArray(op.type.dimensions)

    if op.type.scale != 0:
        op_info["scale"] = op.type.scale
        op_info["zero point"] = op.type.zeroPoint
    if op.type.type == "TENSOR_QUANT8_SYMM_PER_CHANNEL":
        op_info["scale"] = FormatArray(op.type.extraParams.scales)
        op_info["channel dim"] = op.type.extraParams.channelDim

    return op_info


def FormatOperand(op):
    """Renders an operand name as an HTML link with a hover tooltip.

    Every key/value from GetOperandInfo appears in the tooltip. The operand data
    is also shown for a Parameter holding at most 10 values; that covers most
    parameters while keeping the tooltip small. The link targets the in-page
    anchor "#details-operands-<model_index>" (presumably the operand's entry in
    the details view of the HTML template).
    """
    op_info = GetOperandInfo(op)
    if isinstance(op, Parameter) and len(op.value) <= 10:
        op_info["data"] = FormatArray(op.value, op.type.IsScalar())

    template = (
        "<span class='tooltip'><span class='tooltipcontent'>{tooltip_content}</span>"
        "<a href=\"{inpage_link}\">{op_name}</a></span>"
    )
    return template.format(
        op_name=str(op),
        tooltip_content=FormatDict(op_info),
        inpage_link="#details-operands-%d" % (op.model_index),
    )


def GetSubgraph(example):
    """Produces the nodes and edges information for d3 visualization.

    Operations and operands both become nodes. Each operation gets an edge from
    every input operand and an edge to every output operand. Nodes are placed on
    a grid: the row (layer) is derived from the graph structure, and nodes that
    share a layer are placed left to right in topological order.

    Returns:
        {"nodes": [...], "edges": [...]}. Operation nodes have id "operation<i>"
        and group 2; operand nodes use str(operand) as id/name and group 1.
    """
    node_index_map = {}
    topological_order = []

    def AddToTopologicalOrder(op):
        if op not in node_index_map:
            node_index_map[op] = len(topological_order)
            topological_order.append(op)

    # Get the topological order; operands and operations are treated the same.
    # example.model.operations is already topologically sorted, so it suffices
    # to insert each operation between its inputs and its outputs.
    for op in example.model.operations:
        for i in op.ins:
            AddToTopologicalOrder(i)
        AddToTopologicalOrder(op)
        for o in op.outs:
            AddToTopologicalOrder(o)

    # Assign layers to the nodes. Every node in the order exposes .ins and .outs.
    # Forward pass: one layer below the deepest node in .ins (nodes with no .ins
    # land in layer 0). Relies on .ins already having been visited.
    layers = {}
    for node in topological_order:
        layers[node] = max([layers[i] for i in node.ins], default=-1) + 1
    # Backward pass: move every node that has .outs to the layer directly above
    # its earliest consumer, e.g. an operand feeding only a late operation is
    # drawn right above it instead of in layer 0. Nodes without .outs keep
    # their layer.
    for node in reversed(topological_order):
        layers[node] = min([layers[o] for o in node.outs], default=layers[node] + 1) - 1
    num_layers = max(layers.values()) + 1

    # Assign coordinates to the nodes. Nodes are equally spaced.
    def CoordX(index):
        return (index + 0.5) * 200  # 200px spacing horizontally

    def CoordY(index):
        return (index + 0.5) * 100  # 100px spacing vertically

    coords = {}
    layer_cnt = [0] * num_layers
    for node in topological_order:
        layer = layers[node]
        coords[node] = (CoordX(layer_cnt[layer]), CoordY(layer))
        layer_cnt[layer] += 1

    # Create edges and nodes dictionaries for d3 visualization.
    def OpName(idx):
        return "operation%d" % idx

    edges = []
    nodes = []
    for ind, op in enumerate(example.model.operations):
        for tensor in op.ins:
            edges.append({
                "source": str(tensor),
                "target": OpName(ind),
            })
        for tensor in op.outs:
            edges.append({
                "target": str(tensor),
                "source": OpName(ind),
            })
        nodes.append({
            "index": ind,
            "id": OpName(ind),
            "name": op.optype,
            "group": 2,
            "x": coords[op][0],
            "y": coords[op][1],
        })

    for ind, op in enumerate(example.model.operands):
        nodes.append({
            "index": ind,
            "id": str(op),
            "name": str(op),
            "group": 1,
            "x": coords[op][0],
            "y": coords[op][1],
        })

    return {"nodes": nodes, "edges": edges}


# The following Get**Info methods will each return a list of dictionaries,
# whose content will appear in the tables and sidebar views.
def GetConfigurationsInfo(example):
    """Returns a one-row table of the example's configuration flags, as strings."""
    return [{
        "relaxed": str(example.model.isRelaxed),
        "use shared memory": str(tg.Configuration.useSHM()),
        "expect failure": str(example.expectFailure),
    }]


def GetOperandsInfo(example):
    """Returns one row per model operand.

    Each row has index/name/group, the GetOperandInfo fields, and the full
    "data" for Parameter, Input and Output operands (no length limit here).
    """
    ret = []
    for index, op in enumerate(example.model.operands):
        row = {"index": index, "name": str(op), "group": "operand"}
        row.update(GetOperandInfo(op))
        if isinstance(op, (Parameter, Input, Output)):
            row["data"] = FormatArray(op.value, op.type.IsScalar())
        ret.append(row)
    return ret


def GetOperationsInfo(example):
    """Returns one row per operation, with inputs/outputs rendered as tooltip links."""
    return [{
        "index": index,
        "name": op.optype,
        "group": "operation",
        "opcode": op.optype,
        "inputs": ", ".join(FormatOperand(i) for i in op.ins),
        "outputs": ", ".join(FormatOperand(o) for o in op.outs),
    } for index, op in enumerate(example.model.operations)]


# TODO: Remove the unused fd parameter from ProcessExample.
def ProcessExample(example, fd):
    """Process an example and save the information into the global dictionary global_graphs.

    Used as the DumpExample callback passed to Example.DumpAllExamples in the
    main section below.

    Args:
        example: The test_generator Example (one test variation) to visualize.
        fd: Unused. Presumably receives the example_fd given to DumpAllExamples
            and is kept only so the callback signature matches (TODO confirm
            against test_generator).
    """
    # Item assignment mutates the module-level dict; no `global` statement needed.
    print(" Processing variation %s" % example.testName)
    global_graphs[str(example.testName)] = {
        "subgraph": GetSubgraph(example),
        "details": {
            "configurations": GetConfigurationsInfo(example),
            "operands": GetOperandsInfo(example),
            "operations": GetOperationsInfo(example),
        },
    }


def DumpHtml(spec_file, out_file):
    """Dump the final HTML file by replacing entries from a template file.

    Substitutes $spec_name (the basename of spec_file) and $graph_dump
    (global_graphs serialized as JSON) into TEMPLATE_FILE using
    string.Template, then writes the page to out_file.

    Raises:
        KeyError or ValueError from Template.substitute if the template has any
        other placeholder or a malformed "$". The page is rendered before
        out_file is opened, so a bad template leaves out_file untouched.
    """
    with open(TEMPLATE_FILE, "r", encoding="utf-8") as template_fd:
        html_template = template_fd.read()

    page = Template(html_template).substitute(
        spec_name=os.path.basename(spec_file),
        graph_dump=json.dumps(global_graphs),
    )
    with open(out_file, "w", encoding="utf-8") as out_fd:
        out_fd.write(page)


def ParseCmdLine():
    """Parses the command line and initializes tg.FileNames with the spec file.

    Returns:
        (spec_path, out_path) as absolute paths; the output defaults to
        "out.html" relative to the current directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file")
    parser.add_argument("-o", "--out", help="the output html path", default="out.html")
    args = parser.parse_args()
    # NOTE(review): the meaning of the "-" argument is defined by
    # test_generator.FileNames.InitializeFileLists — confirm there.
    tg.FileNames.InitializeFileLists(args.spec, "-")
    tg.FileNames.NextFile()
    return os.path.abspath(args.spec), os.path.abspath(args.out)


if __name__ == '__main__':
    spec_file, out_file = ParseCmdLine()
    print("Visualizing from spec: %s" % spec_file)
    # The spec must run at module level: exec() then uses this module's globals,
    # so the spec sees the test_generator names imported at the top of the file.
    # It is arbitrary Python, so only run this tool on trusted spec files.
    # Compiling with the real path makes tracebacks point into the spec.
    with open(spec_file, "r", encoding="utf-8") as spec_fd:
        spec_source = spec_fd.read()
    exec(compile(spec_source, spec_file, "exec"))
    Example.DumpAllExamples(DumpExample=ProcessExample, example_fd=0)
    DumpHtml(spec_file, out_file)
    print("Output HTML file: %s" % out_file)