#!/usr/bin/python3

# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NN model compiler

Contain classes definition and utility functions for compiling models and
examples into NDK-based CTS and VTS unit tests.

Used by example_generator.py and spec_visualizer.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from functools import reduce
import argparse
import io
import itertools
import os
import re
import sys
import traceback
import numpy as np

def GetJointStr(l, sep=", ", method=str):
    """Join the items of `l` with `sep` after converting each with `method`."""
    return sep.join([method(i) for i in l])

# Print in C float literal format
def PrettyPrintAsFloat(x):
    """Return `x` formatted as a C float literal, e.g. 1 -> "1.0f".

    NOTE(review): str(float(x)) has neither "." nor "e" only for inf/nan, so
    the fallback branch yields "inf.0f"/"nan.0f" -- confirm whether such
    values can reach this function.
    """
    s = str(float(x))
    if s.find(".") >= 0 or s.find("e") >= 0:
        return s + "f"
    else:
        return s + ".0f"

# Transform from original type to float32
def Dequantize(v, ty):
    """Map the quantized value(s) `v` of type `ty` back to real values.

    Subtracts ty.zeroPoint, then multiplies by ty.scale (when non-zero) and by
    the per-channel scales (when ty.extraParams carries them). The in-place
    operators mean an ndarray argument is modified in place.
    """
    v -= ty.zeroPoint
    if ty.scale != 0:
        v *= ty.scale
    if isinstance(ty.extraParams, SymmPerChannelQuantParams):
        v *= ty.extraParams.GetScalesBroadcastArray(ty.dimensions)
    return v

# Transform float32 to target data type
def Quantize(v, ty):
    """Map the real value(s) `v` to the representation of type `ty`.

    Divides by ty.scale (when non-zero) and by the per-channel scales (when
    present), adds ty.zeroPoint, rounds to integers for non-float types, and
    finally clamps to the value range of the listed quantized/unsigned types.
    An ndarray argument may be modified in place by the in-place operators.
    """
    if ty.scale != 0:
        v /= ty.scale
    if isinstance(ty.extraParams, SymmPerChannelQuantParams):
        v = v / ty.extraParams.GetScalesBroadcastArray(ty.dimensions)
    v += ty.zeroPoint
    if not ty.IsFloat():
        v = np.round(v)
        v = v.astype(int)

    if ty.type == "TENSOR_QUANT8_ASYMM":
        v = np.minimum(np.maximum(v, 0), 255)
    elif ty.type == "TENSOR_QUANT16_ASYMM":
        v = np.minimum(np.maximum(v, 0), 65535)
    elif ty.type == "TENSOR_QUANT8_SYMM_PER_CHANNEL":
        v = np.minimum(np.maximum(v, -127), 127)
    elif ty.type == "UINT32":
        v = np.maximum(v, 0)
    elif ty.type == "TENSOR_QUANT8_ASYMM_SIGNED":
        v = np.minimum(np.maximum(v, -128), 127)
    return v

# Tracking objects inside a model with a unique name
class NamedObject:
    """Base class giving every instance a name that is unique per namespace.

    The namespace is the `existingNames` set found on the instance's class;
    subclasses that define their own set get their own namespace.
    """
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        """Build the name from the non-empty `args` joined with `sep`.

        Unless `skipRenaming` is set, a numeric suffix (counting up from
        `startsFrom`) is appended until the name is unused. With `showZero`
        the first candidate already carries the `startsFrom` suffix.
        """
        name = GetJointStr([i for i in args if i is not None and i != ""], sep=sep)
        if skipRenaming:
            # The name is taken verbatim and is not registered in existingNames.
            self.name = name
            return
        # make the name unique by renaming with a suffix number
        uniqueName = name if showZero is False else name + sep + str(startsFrom)
        while uniqueName in self.__class__.existingNames:
            startsFrom += 1
            uniqueName = name + sep + str(startsFrom)
        self.__class__.existingNames.add(uniqueName)
        self.name = uniqueName

    def __str__(self):
        return self.name
    __repr__ = __str__

    # Since names are unique, objects with the same name are considered equal
    def __eq__(self, other):
        return isinstance(other, NamedObject) and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)

    def __lt__(self, other):
        return self.name < other.name

# Types, operands should all have a unique name since they share the same namespace
class NamedVariable(NamedObject):
    """NamedObject with a namespace of its own, shared by all its subclasses
    that do not redefine `existingNames`."""
    existingNames = set()
    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        NamedObject.__init__(self, *args, sep=sep, showZero=showZero,
                             startsFrom=startsFrom, skipRenaming=skipRenaming)

# Global variables in the spec namespace such as CreateModel, is_ignored, and examples
class GlobalVariable(NamedVariable):
    """Named variable whose de-duplication suffix starts counting from 1."""
    def __init__(self, *args, skipRenaming=False):
        NamedObject.__init__(self, *args, startsFrom=1, skipRenaming=skipRenaming)

# Each test should have a unique name, but will not conflict with variables
class NamedTest(NamedObject):
    """Test name kept in a namespace separate from variables."""
    existingNames = set()
    def __init__(self, *args, startsFrom=0, skipRenaming=False):
        # NOTE(review): the `startsFrom` argument is accepted but the call
        # below always passes 1 -- confirm this is intended.
        NamedObject.__init__(self, *args, startsFrom=1, skipRenaming=skipRenaming)

class Type(NamedVariable):
    """Operand type: type name, dimensions, scale, zero point and extra params.

    Instances are meant to be obtained through GetType()/GetTypeFromString(),
    which cache them in `typesMap`.
    """
    # Cache of created types, keyed by the string built in GetType().
    typesMap = dict()
    # Maps each operand type name to the C++ element type used to store it.
    typeLookup = {
        "INT32": "int32_t",
        "UINT32": "uint32_t",
        "FLOAT32": "float",
        "FLOAT16": "_Float16",
        "TENSOR_INT32": "int32_t",
        "TENSOR_FLOAT16": "_Float16",
        "TENSOR_FLOAT32": "float",
        "TENSOR_QUANT8_ASYMM": "uint8_t",
        "TENSOR_QUANT8_SYMM": "int8_t",
        "BOOL": "bool8",
        "TENSOR_QUANT16_ASYMM": "uint16_t",
        "TENSOR_QUANT16_SYMM": "int16_t",
        "TENSOR_BOOL8": "bool8",
        "TENSOR_QUANT8_SYMM_PER_CHANNEL": "int8_t",
        "TENSOR_QUANT8_ASYMM_SIGNED": "int8_t",
        # "OEM_SCALAR": this is service-defined.
        "TENSOR_OEM_BYTE": "uint8_t",
        "SUBGRAPH": "uint32_t",  # Index into TestModel::referenced.
    }

    # types are named as "type0", "type1", ...
    def __init__(self, vt, dimensions, scale, zeroPoint, name="type", skipRenaming=False,
                 extraParams=None):
        NamedVariable.__init__(self, name, sep="", showZero=True, skipRenaming=skipRenaming)
        self.type = vt
        self.dimensions = dimensions
        self.scale = float(scale)
        self.zeroPoint = int(zeroPoint)
        self.extraParams = extraParams

    # Factory for Type object, only create a new Type if requested type does
    # not have a match with all existing types
    @staticmethod
    def GetType(vt, dimensions, scale=0, zeroPoint=0, extraParams=None):
        """Return the cached Type for these arguments, creating it if needed.

        The cache key is built from the str() of every argument.
        """
        assert isinstance(dimensions, (list, tuple)), \
            'dimensions must be a list or tuple, got {}'.format(type(dimensions))
        key = ",".join([vt, str(dimensions), str(scale), str(zeroPoint), str(extraParams)])
        if key not in Type.typesMap:
            Type.typesMap[key] = Type(vt, dimensions, scale, zeroPoint, extraParams=extraParams)
        return Type.typesMap[key]

    @staticmethod
    def GetAllTypes():
        """Return every cached Type, sorted by name."""
        # sort to ensure a stable order when dumping the code
        return sorted(Type.typesMap.values())

    # For backward-compatibility
    @staticmethod
    def GetTypeFromString(vt, shape, extraParams=None):
        """Return the Type described by `vt` and a legacy shape string."""
        dimensions, scale, zeroPoint = Type.GetParsedShape(shape)
        scale = float(scale)
        zeroPoint = int(zeroPoint)
        return Type.GetType(vt, dimensions, scale, zeroPoint, extraParams)

    # For backward-compatibility
    @staticmethod
    def GetParsedShape(shape):
        """Parse a legacy shape string such as "{1, 2}, 0.5f, 127".

        Returns a tuple (dimensions, scale, zeroPoint) where dimensions is a
        list of ints and the other two are strings ("0" when absent). The C
        "f" suffix is removed from the scale in every branch so that the
        result can be passed to float().
        """
        # Parse shape
        if (shape != "" and shape != "{}"):
            left, sep, right = shape.partition('{')
            real_shape, sep, right = right.partition('}')
            shape = [int(x) for x in real_shape.split(",")]
            # right now looks like ", 0.0f, 127" (scale and zero point are optional)
            scale, sep, zero_point = right.rpartition(',')
            if scale == "":
                if zero_point == "":
                    return shape, "0", "0"
                # Only one value follows the dimensions: it is the scale, and
                # may carry the "f" suffix just like in the two-value form.
                return shape, zero_point.replace("f", ""), "0"
            left, sep, scale = scale.partition(',')
            return shape, scale.replace("f", ""), zero_point
        else:
            return [], "0", "0"

    def GetNumberOfElements(self):
        """Return the product of the dimensions (1 for an empty list)."""
        return reduce(lambda x,y: x*y, self.dimensions, 1)

    def GetCppTypeString(self):
        """Return the C++ element type looked up in `typeLookup`."""
        return Type.typeLookup[self.type]

    def IsFloat(self):
        return self.GetCppTypeString() in ["float", "_Float16"]

    def IsBool(self):
        return self.GetCppTypeString() == "bool8"

    def IsScalar(self):
        return not self.type.startswith("TENSOR_")

    def GetSignatureTuple(self):
        """Return (type, dimensions, scale, zeroPoint); extraParams is not included."""
        return (self.type, self.dimensions, self.scale, self.zeroPoint)

# To track implicitly convertible parameter types
class ImplicitParameter():
    """Mixin marking parameter classes that plain Python values convert to."""
    @staticmethod
    def ImplicitConvertion(value):
        """Return `value` as an operand.

        Operands are returned unchanged; any other value is wrapped by the
        first direct subclass whose IsCompatible(value) is true, using the
        name "param".
        """
        if isinstance(value, Operand):
            return value
        for implicitType in ImplicitParameter.__subclasses__():
            if implicitType.IsCompatible(value):
                return implicitType("param", value)
        assert False, "%s not supported for implicit parameter"%value


# ExtraParams with per-channel quantization.
class SymmPerChannelQuantParams():
    """Per-channel quantization parameters: channel axis and one scale per channel."""
    def __init__(self, channelDim, scales, hide = False):
        self.channelDim = channelDim
        self.scales = scales
        self.hide = hide

    def GetScalesBroadcastArray(self, dimensions):
        """Return the scales as an array of rank len(dimensions) with size 1
        on every axis except `channelDim`, ready for broadcasting."""
        bshape = [1] * len(dimensions)
        bshape[self.channelDim] = len(self.scales)
        return np.array(self.scales).reshape(bshape)


# An operand that can be fed into operations. Also, an operand is always
# declared before operations.
class Operand(NamedVariable):
    """A typed, named value that operations consume or produce."""

    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        """Create the operand.

        `opType` is either a tuple of Type.GetType() arguments or, in the
        legacy form, a type-name string; in the legacy form `value` holds the
        shape string and the actual value is taken from `backward`.
        """
        NamedVariable.__init__(self, name, sep="", skipRenaming=skipRenaming)
        legacyForm = type(opType) is str
        if legacyForm:
            self.type = Type.GetTypeFromString(opType, value, extraParams)
            value = backward
        else:
            self.type = Type.GetType(*opType, extraParams=extraParams)
        self.SetValue(value)
        self.lifetime = "TEMPORARY_VARIABLE"
        self.model_index = None
        self.ins = []
        self.outs = []
        self.mayBeInternal = True

    def SetValue(self, value):
        """Store `value`; anything that is not None, a list or a tuple is
        wrapped in a one-element list. Returns self."""
        if value is None or type(value) is list or type(value) is tuple:
            self.value = value
        else:
            self.value = [value]
        return self

    def SetValueFromNumpy(self, value):
        """Store the flattened contents of the array `value`. Returns self."""
        flat = value.flatten()
        self.value = flat.tolist()
        return self

    def GetValueAsNumpy(self):
        """Return the value as an array shaped like the type's dimensions."""
        asArray = np.array(self.value)
        return asArray.reshape(self.type.dimensions)

    # Print value as cpp-style list initialization
    def GetListInitialization(self):
        """Return the value as a brace-enclosed C++ initializer list."""
        if self.value is None:
            return "{}"
        # Pick the per-element formatter according to the element type.
        if self.type.IsFloat():
            formatter = PrettyPrintAsFloat
        elif self.type.IsBool():
            formatter = lambda v: "true" if v else "false"
        else:
            formatter = lambda x: str(int(x))
        return "{%s}"%(GetJointStr(self.value, method=formatter))

    def ConvertTo(self, DerivedClass, name=None):
        """Return a new `DerivedClass` operand with this operand's type.

        The value is copied unless the target class is Internal. The
        never-internal flag is carried over, and converting a never-internal
        operand to Internal fails an assertion.
        """
        assert issubclass(DerivedClass, Operand)
        if name is None:
            name = self.name
        becomesInternal = issubclass(DerivedClass, Internal)
        newop = DerivedClass(name, self.type.GetSignatureTuple(), skipRenaming=True,
                             extraParams=self.type.extraParams)
        if not becomesInternal:
            newop.SetValue(self.value)
        if not self.mayBeInternal:
            assert not becomesInternal
            newop.ShouldNeverBeInternal()
        return newop

    def ShouldNeverBeInternal(self):
        """Mark this operand as one that must not become Internal. Returns self."""
        self.mayBeInternal = False
        return self

# Base class of user-defined input/output operand
class InOut(Operand):
    """Common base of Input and Output operands."""

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        Operand.__init__(self, name, opType, backward, None,
                         skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_INPUT"
        self.index = 0

    def Feed(self, value):
        """Set the value; a dict is looked up with this operand as the key."""
        if type(value) is dict:
            value = value[self]
        self.SetValue(value)
        return self

# A user-declared input operand
class Input(InOut):
    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        InOut.__init__(self, name, opType, backward,
                       skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_INPUT"

# A user-declared output operand
class Output(InOut):
    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        InOut.__init__(self, name, opType, backward,
                       skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_OUTPUT"

# An output that we don't want to compare the results
class IgnoredOutput(Output):
    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        Output.__init__(self, name, opType, backward,
                        skipRenaming=skipRenaming, extraParams=extraParams)
        self.lifetime = "SUBGRAPH_OUTPUT"
    def Feed(self, value):
        """Ignore `value` and fill the operand with one zero per element."""
        numElements = 1
        for extent in self.type.dimensions:
            numElements = numElements * extent
        self.value = [0] * numElements
        return self

# An explicitly declared parameter
class Parameter(Operand):
    """Constant operand; its lifetime depends on whether it has a value and
    on Configuration.useSHM()."""
    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        Operand.__init__(self, name, opType, value, backward, skipRenaming=skipRenaming,
                         extraParams=extraParams)
        self.initializer = NamedVariable(str(self) + "_init")
        # The decision looks at the `value` argument as passed by the caller.
        if value is None:
            lifetime = "NO_VALUE"
        elif Configuration.useSHM():
            lifetime = "CONSTANT_REFERENCE"
        else:
            lifetime = "CONSTANT_COPY"
        self.lifetime = lifetime

# A shortcut for parameters of INT32
class Int32Scalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("INT32", []), int(value))
    @staticmethod
    def IsCompatible(value):
        # Exact type check: bool values are not treated as ints.
        return type(value) is int

# A shortcut for parameters of FLOAT16
class Float16Scalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("FLOAT16", []), float(value))
    @staticmethod
    def IsCompatible(value):
        # Never selected by implicit conversion.
        return False

# A shortcut for parameters of FLOAT32
class Float32Scalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("FLOAT32", []), float(value))
    @staticmethod
    def IsCompatible(value):
        return type(value) is float

# A shortcut for parameters of BOOL
class BoolScalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("BOOL", []), bool(value))
    @staticmethod
    def IsCompatible(value):
        return type(value) is bool

# A shortcut for parameter of 1-D TENSOR_INT32
class Int32Vector(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("TENSOR_INT32", [len(value)]), [int(v) for v in value])
    @staticmethod
    def IsCompatible(value):
        if type(value) is not list and type(value) is not tuple:
            return False
        return all(type(i) is int for i in value)

# A shortcut for parameter of 1-D TENSOR_FLOAT32
class Float32Vector(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("TENSOR_FLOAT32", [len(value)]), [float(v) for v in value])
    @staticmethod
    def IsCompatible(value):
        if type(value) is not list and type(value) is not tuple:
            return False
        return all(type(i) is float for i in value)

# A shortcut for a SUBGRAPH parameter
class SubgraphReference(Parameter, ImplicitParameter):
    def __init__(self, name, model):
        Parameter.__init__(self, name, ("SUBGRAPH", []), model)
        self.lifetime = "SUBGRAPH"
        # An unnamed model takes the name of the operand referencing it.
        if model.name is None:
            model.name = name
    @staticmethod
    def IsCompatible(value):
        return type(value) is Model

# An explicitly declared intermediate result
class Internal(Operand):
    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        Operand.__init__(self, name, opType, backward, None, skipRenaming=skipRenaming,
                         extraParams=extraParams)
        self.lifetime = "TEMPORARY_VARIABLE"

# An operation in a model, does not need a name
class Operation:
    """An operation type together with its input and output operands."""

    def __init__(self, optype, ins, outs):
        self.optype = optype
        self.SetInputs(ins)
        self.SetOutputs(outs)

    # for the ease of debugging
    def __str__(self):
        insString = GetJointStr(self.ins)
        outsString = GetJointStr(self.outs)
        return "Operation %s: [%s] -> [%s]"%(self.optype, insString, outsString)
    __repr__ = __str__

    def SetInputs(self, ins):
        """Store the inputs, converting plain Python values to parameters."""
        self.ins = [ImplicitParameter.ImplicitConvertion(i) for i in ins]
        return self

    def SetOutputs(self, outs):
        """Store a copy of the outputs as a list."""
        self.outs = list(outs)
        return self

# Main interface
class Model:
    """A graph of operations and operands, built through chainable calls.

    Every instance registers itself in the class-level `models` list.
    """
    models = list()

    def __init__(self, name=None):
        self.name = name
        self.operations = []
        self.operands = []
        self.isRelaxed = False
        self.compiled = False
        self.dumped = False
        self.version = FileNames.version
        self.referenced_models = None
        Model.models.append(self)

    def AddOperand(self, operand):
        """Append `operand` unless an equal operand is already present."""
        if operand not in self.operands:
            self.operands.append(operand)
        return self

    # Makes sure the model contains all (and only) the given inputs in the
    # specified order.
    def IdentifyInputs(self, *args):
        for arg in args:
            self.AddOperand(arg)
        inputs = tuple(self.GetInputs())
        assert inputs == args, '{} vs {}'.format(inputs, args)
        return self

    # Makes sure the model contains all (and only) the given outputs in the
    # specified order.
    def IdentifyOutputs(self, *args):
        for arg in args:
            self.AddOperand(arg)
        outputs = tuple(self.GetOutputs())
        assert outputs == args, '{} vs {}'.format(outputs, args)
        return self

    def AddOperation(self, operation):
        """Append `operation` and register all of its operands."""
        self.operations.append(operation)
        for i in operation.ins:
            self.AddOperand(i)
        for o in operation.outs:
            self.AddOperand(o)
        return self

    def Operation(self, op_name, *args):
        """Add an operation named `op_name` with inputs `args` and no outputs
        yet; outputs are attached with To()."""
        return self.AddOperation(Operation(op_name, args, []))

    def To(self, *args):
        """Set the outputs of the most recently added operation.

        Accepts either the outputs as separate arguments or a single
        list/tuple of outputs.
        """
        assert len(self.operations) > 0
        if type(args[0]) is tuple or type(args[0]) is list:
            outs = args[0]
        else:
            outs = args
        self.operations[-1].SetOutputs(outs)
        for o in outs:
            self.AddOperand(o)
        return self

    def RelaxedExecution(self, isRelaxed):
        self.isRelaxed = isRelaxed
        return self

    # Sets the version of the model in compliance tests. Set to None to disable the test.
    def IntroducedIn(self, ver):
        self.version = ver
        return self

    def GetTypes(self):
        """Return the distinct operand types, sorted."""
        return sorted(set(op.type for op in self.operands))

    def GetInputs(self):
        return [i for i in self.operands if isinstance(i, Input)]

    def GetOutputs(self):
        return [o for o in self.operands if isinstance(o, Output)]

    def GetInputsIndex(self):
        return [i for i,op in enumerate(self.operands) if isinstance(op, Input)]

    def GetOutputsIndex(self):
        return [o for o,op in enumerate(self.operands) if isinstance(op, Output)]

    def GetIndexOfOperands(self, operands):
        return [self.operands.index(i) for i in operands]

    def GetIgnoredOutputs(self):
        return [o for o in self.operands if isinstance(o, IgnoredOutput)]

    def GetParameters(self):
        return [p for p in self.operands if isinstance(p, Parameter)]

    def GetReferencedModels(self):
        assert self.compiled
        return self.referenced_models

    def GetEquivalentOperands(self, targets):
        """Return the model's own operand objects that compare equal to `targets`."""
        return [self.operands[self.operands.index(t)] for t in targets]

    def UpdateEquivalentOperands(self, targets):
        """Replace each stored operand that compares equal to a target with
        that target object."""
        for t in targets:
            self.operands[self.operands.index(t)] = t
        return self

    def SetOperandIndex(self):
        """Assign input/output indices and the per-model operand index."""
        for ind, i in enumerate(self.GetInputs()):
            i.index = ind
        for ind, o in enumerate(self.GetOutputs()):
            o.index = ind
        for ind, op in enumerate(self.operands):
            op.model_index = ind
        return self

    def SetOperandInsAndOuts(self):
        """Rebuild, for every operand, the operations producing it (`ins`)
        and consuming it (`outs`)."""
        for op in self.operands:
            op.ins = list()
            op.outs = list()
        for op in self.operations:
            op.ins = self.GetEquivalentOperands(op.ins)
            op.outs = self.GetEquivalentOperands(op.outs)
            for i in op.ins:
                i.outs.append(op)
            for o in op.outs:
                o.ins.append(op)
        return self

    def TopologicalSortHelper(self, op, deps, visited):
        """Depth-first visit of `op`: emit its dependencies, then `op` itself.

        An operation is removed from `deps` once emitted, so meeting a
        visited operation that is still in `deps` means it is on the current
        path, i.e. the graph has a cycle.
        """
        if op in visited:
            assert op not in deps, "Cycle detected in the graph"
        else:
            visited.add(op)
            for i in deps[op]:
                self.TopologicalSortHelper(i, deps, visited)
            self.operations.append(op)
            deps.pop(op)

    # Topological sort of the operations, and detect if there is a cycle in the graph
    def TopologicalSort(self):
        deps = {op: list() for op in self.operations}
        # Every consumer of an operand depends on every producer of it.
        for operand in self.operands:
            for consumer in operand.outs:
                for producer in operand.ins:
                    deps[consumer].append(producer)
        operations = self.operations.copy()
        self.operations = []
        visited = set()
        for op in operations:
            self.TopologicalSortHelper(op, deps, visited)

    def CompileReferencedModels(self, referenced_models, referenced_model_to_index):
        """Compile the models held by SUBGRAPH operands and replace each such
        operand's value with the model's index in `referenced_models`."""
        for operand in self.operands:
            if operand.lifetime != "SUBGRAPH":
                continue
            model = operand.value[0]
            key = id(model)
            if key not in referenced_model_to_index:
                referenced_model_to_index[key] = len(referenced_model_to_index)
                referenced_models.append(model)
            model.Compile(referenced_models, referenced_model_to_index)
            operand.value = [referenced_model_to_index[key]]

    def Compile(self, referenced_models=None, referenced_model_to_index=None):
        """Index operands, sort operations and compile referenced models.

        Does nothing when the model is already compiled. Called without
        arguments for the main model; the two arguments are passed along when
        compiling referenced models.
        """
        if self.compiled:
            return self
        if referenced_models is None:
            # This is the main model.
            referenced_models = []
            referenced_model_to_index = {}
        self.referenced_models = referenced_models
        self.SetOperandIndex()
        self.SetOperandInsAndOuts()
        self.TopologicalSort()
        self.CompileReferencedModels(referenced_models, referenced_model_to_index)
        # Do not check compliance for relaxed mode tests.
        if self.isRelaxed:
            self.IntroducedIn(None)
        self.compiled = True
        return self

    def Feed(self, feedDict):
        """Feed inputs from feedDict[0] and outputs from feedDict[1]."""
        for i in self.GetInputs():
            i.Feed(feedDict[0])
        for o in self.GetOutputs():
            o.Feed(feedDict[1])
        return self

# To track implicitly convertible variation types
class ImplicitVariation:
    """Mixin marking variation classes that plain values convert to."""
    @staticmethod
    def ImplicitConvertion(value):
        """Return `value` as a ModelVariation.

        A ModelVariation is returned unchanged. Otherwise `value` (or its
        first element, for a list/tuple) selects the first direct subclass
        whose IsCompatible() accepts it; remaining elements are passed to
        Identify().
        """
        if isinstance(value, ModelVariation):
            return value
        # Normalise once, so the failure message below always refers to the
        # requested variation rather than to a slice of it.
        value = value if type(value) is tuple or type(value) is list else [value]
        for implicitType in ImplicitVariation.__subclasses__():
            if implicitType.IsCompatible(value[0]):
                var = implicitType(value[0])
                if len(value) > 1:
                    var.Identify(*value[1:])
                return var
        assert False, "%s not supported for implicit variation"%value[0]

# An exception indicating that the current variation list should be skipped.
class SkipVariation(Exception):
    pass

# The base class for model variations
class ModelVariation:
    """Base class of transformations applied to a model and its operands."""
    # Whether the variation may be applied to models with SUBGRAPH operands.
    supportsSubgraphs = False

    def __init__(self, name=None):
        # Maps each target operand to the argument given to TransformOperand().
        self.targetOperands = {}
        self.name = name

    # Apply the model variation.
    def ApplyTo(self, model):
        """Transform the target operands, then the model; return the model.

        The model must be neither compiled nor dumped. When no targets were
        identified explicitly, AutoIdentify() is given the chance to pick them.
        """
        assert not model.compiled
        assert not model.dumped

        if not self.supportsSubgraphs:
            containsSubgraphs = any(operand.lifetime == "SUBGRAPH" for operand in model.operands)
            assert not containsSubgraphs, "Variation {} does not support subgraphs".format(
                self.__class__.__name__)

        if not self.targetOperands:
            self.AutoIdentify(model)

        # Transform operands and model.
        targets = model.GetEquivalentOperands(sorted(self.targetOperands.keys()))
        model.UpdateEquivalentOperands(
            [self.TransformOperand(op, self.targetOperands[op]) for op in targets])
        model = self.TransformModel(model)
        return model

    def IdentifyOperands(self, args=None):
        """Set the target operands from a dict (operand -> argument) or from
        an iterable of operands (argument None). None leaves them untouched."""
        if args is None:
            return self
        self.targetOperands = args if type(args) is dict else {i: None for i in args}
        return self

    def Identify(self, operandArgs=None, paramArgs=None):
        # `paramArgs` is accepted but not used here.
        self.IdentifyOperands(operandArgs)
        return self

    # Set variation to its default name
    def SetToDefaultName(self):
        self.name = ""
        return self

    # Automatically select the target operand list
    def AutoIdentify(self, model):
        return self

    # Transform operands that are marked by IdentifyOperands()
    def TransformOperand(self, op, arg=None):
        return op

    # Transform the model
    def TransformModel(self, model):
        return model

# Default variation that does nothing
class DefaultVariation(ModelVariation):
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

# Convert operand data type
class DataTypeConverter(ModelVariation, ImplicitVariation):
    """Variation converting operands (and their values) to another data type."""
    supportsSubgraphs = True

    def __init__(self, targetType=None, name=None, scale=None, zeroPoint=None):
        ModelVariation.__init__(self, name=name)
        if targetType is not None:
            assert DataTypeConverter.IsCompatible(targetType)
        self.targetType = targetType
        self.scale = scale
        self.zeroPoint = zeroPoint

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ["float16", "int32", "quant8", "quant8_signed"]

    def SetToDefaultName(self):
        """Name the variation after the target type, or, when none was given,
        after the first matching tensor type among the per-operand arguments."""
        if self.targetType is not None:
            self.name = self.targetType.lower()
            return self
        # First element of every non-converter argument, i.e. the type names.
        targetTypes = list(zip(*(arg for arg in self.targetOperands.values()
                                 if type(arg) is not DataTypeConverter)))[0]
        if "TENSOR_QUANT8_SYMM_PER_CHANNEL" in targetTypes:
            self.name = "channelQuant8"
        elif "TENSOR_QUANT8_ASYMM" in targetTypes:
            self.name = "quant8"
        elif "TENSOR_QUANT8_ASYMM_SIGNED" in targetTypes:
            self.name = "quant8_signed"
        elif "TENSOR_INT32" in targetTypes:
            self.name = "int32"
        elif "TENSOR_FLOAT16" in targetTypes:
            self.name = "float16"
        else:
            self.name = "float32"
        return self

    def AutoIdentify(self, model):
        """When a target type is set, select all float32 tensors (and float32
        scalars, except for the quant8 targets) plus all SUBGRAPH operands."""
        if self.targetType is not None:
            if self.targetType == "quant8" or self.targetType == "quant8_signed":
                if self.targetType == "quant8":
                    tensorType = "TENSOR_QUANT8_ASYMM"
                else:
                    tensorType = "TENSOR_QUANT8_ASYMM_SIGNED"
                assert self.scale is not None
                assert self.zeroPoint is not None
                tensorType = [tensorType, self.scale, self.zeroPoint]
                scalarType = None  # Not supported.
            else:
                tensorType = ["TENSOR_" + self.targetType.upper()]
                scalarType = [self.targetType.upper()]
            # By default, select all the float32 tensors/scalars
            targets = dict()
            # Referenced models get a converter of their own with the same settings.
            targets.update({op: DataTypeConverter(self.targetType, self.name,
                                                  self.scale, self.zeroPoint)
                            for op in model.operands if op.type.type == "SUBGRAPH"})
            targets.update({op: tensorType
                            for op in model.operands if op.type.type == "TENSOR_FLOAT32"})
            if scalarType is not None:
                targets.update({op: scalarType
                                for op in model.operands if op.type.type == "FLOAT32"})
            self.Identify(targets)
        return self

    def TransformOperand(self, op, arg=None):
        """Convert `op` to the type described by `arg`.

        `arg` is either a DataTypeConverter (applied to the model held by a
        SUBGRAPH operand) or a sequence [typeName, scale?, zeroPoint?].
        """
        if type(arg) is DataTypeConverter:
            # Handle nested SUBGRAPHs
            assert len(op.value) == 1
            assert type(op.value[0]) is Model
            op.value[0] = arg.ApplyTo(op.value[0])
            return op
        if len(arg) == 1:
            typeTuple = (arg[0], op.type.dimensions)
        else:
            typeTuple = (arg[0], op.type.dimensions, *arg[1:])
        # To handle Internal operands
        if op.value is None or op.type.GetNumberOfElements() == 0:
            op.type = Type.GetType(*typeTuple)
        else:
            # Go through float32 so the value is re-quantized for the new type.
            v = Dequantize(op.GetValueAsNumpy().astype(np.float32), op.type)
            op.type = Type.GetType(*typeTuple)
            v = Quantize(v, op.type)
            op.SetValueFromNumpy(v)
        return op

# Convert model to turn on/off relaxed computation
class RelaxedModeConverter(ModelVariation, ImplicitVariation):
    supportsSubgraphs = True

    def __init__(self, isRelaxed=True, name=None):
        ModelVariation.__init__(self, name=name)
        if isinstance(isRelaxed, bool):
            self.isRelaxed = isRelaxed
        else:
            # A string argument must be "relaxed" (case-insensitive).
            assert RelaxedModeConverter.IsCompatible(isRelaxed.lower())
            self.isRelaxed = True

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ["relaxed"]

    def SetToDefaultName(self):
        self.name = "relaxed" if self.isRelaxed else "float"
        return self

    def TransformModel(self, model):
        model.RelaxedExecution(self.isRelaxed)
        return model

# Convert data layout between "NHWC" and "NCHW"
class DataLayoutConverter(ModelVariation, ImplicitVariation):

    def __init__(self, targetLayout="nchw", name=None):
        ModelVariation.__init__(self, name=name)
        self.targetLayout = targetLayout.lower()
        assert DataLayoutConverter.IsCompatible(self.targetLayout)
        # Axis permutation applied to 4-D operands for the target layout.
        self.perm = (0, 3, 1, 2) if self.targetLayout == "nchw" else (0, 2, 3, 1)
        # Value written to BOOL operands (the layout flag).
        self.param = True if self.targetLayout == "nchw" else False

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ["nhwc", "nchw"]

    def SetToDefaultName(self):
        self.name = self.targetLayout
        return self

    def TransformOperand(self, op, arg=None):
        """Permute 4-D operands and 4-element 1-D operands, and set BOOL
        operands to the layout flag; anything else fails an assertion."""
        if len(op.type.dimensions) == 4:
            # To handle Internal operands
            if op.value is not None and op.type.GetNumberOfElements() != 0:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            newDim = [op.type.dimensions[i] for i in self.perm]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        elif (len(op.type.dimensions) == 1 and op.value is not None
              and len(op.value) == 4):
            # The None check lets a value-less 1-D operand reach the explicit
            # "not supported" assertion instead of failing inside len().
            op.SetValueFromNumpy(op.GetValueAsNumpy()[list(self.perm)])
        elif op.type.type == "BOOL":
            op.SetValue(self.param)
        else:
            assert False, "%s not supported by DataLayoutConverter"%op
        return op

# Convert data by transposing and removing axis
class AxisConverter(ModelVariation):

    def __init__(self, origin, target, dim, drop=(), name=None):
        """Move axis `origin` to position `target` in operands of rank `dim`,
        then remove the axes listed in `drop` (an int or an iterable of ints).

        `origin`, `target` and every `drop` entry must lie in [-dim, dim);
        `target` must not be one of the dropped axes.
        """
        ModelVariation.__init__(self, name=name)
        self.origin = origin
        self.target = target
        assert all(i >= -dim and i < dim for i in [self.origin, self.target])
        self.dim = dim
        self.perm = list(range(dim))
        self.perm.insert(target if target >= 0 else target + dim, self.perm.pop(origin))
        self.drop = [drop] if type(drop) is int else list(drop)
        assert all(i >= -dim and i < dim for i in self.drop)
        # Normalise negative axes to their non-negative equivalents.
        self.drop = [i if i >= 0 else i + dim for i in self.drop]
        assert target not in self.drop and target + dim not in self.drop

    def SetToDefaultName(self):
        axis = self.target if self.target >= 0 else self.target + self.dim
        # Axes dropped before the target shift its final position down.
        axis -= sum(i < axis for i in self.drop)
        neg = "" if self.target >= 0 else "_neg"
        self.name = "dim%d_axis%d%s"%(self.dim - len(self.drop), axis, neg)
        return self

    def TransposeAxis(self, op):
        """Set INT32 operands to the target axis and permute rank-`dim` operands."""
        if op.type.type == "INT32":
            op.SetValue(self.target)
        elif len(op.type.dimensions) == self.dim:
            # To handle Internal operands
            if op.value is not None:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            newDim = [op.type.dimensions[i] for i in self.perm]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def RemoveAxis(self, op):
        """Adjust INT32 axis operands for the dropped axes, and keep only
        index 0 along each dropped axis of rank-`dim` operands."""
        if op.type.type == "INT32":
            if op.value[0] >= 0:
                op.SetValue(op.value[0] - sum(i < op.value[0] for i in self.drop))
            else:
                op.SetValue(op.value[0] + sum(i > (op.value[0] + self.dim) for i in self.drop))
        elif len(op.type.dimensions) == self.dim:
            if op.value is not None:
                val = op.GetValueAsNumpy()
                # Highest axis first, so the remaining axis numbers stay valid.
                for i in sorted(self.drop, reverse=True):
                    val = np.take(val, 0, axis=i)
                op.SetValueFromNumpy(val)
            newDim = [op.type.dimensions[i] for i in range(self.dim) if i not in self.drop]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def TransformOperand(self, op, arg=None):
        op = self.TransposeAxis(op)
        op = self.RemoveAxis(op)
        return op

# Convert Output based on activation
class ActivationConverter(ModelVariation, ImplicitVariation):
    # (Enum, low, high)
    actMap = {
        "none": (0, None, None),
        "relu": (1, 0.0, None),
        "relu1": (2, -1.0, 1.0),
        "relu6": (3, 0.0, 6.0),
    }
    def __init__(self, act="relu", name=None):
        ModelVariation.__init__(self, name=name)
        self.act = act.lower()
        assert ActivationConverter.IsCompatible(self.act)
        self.enum = ActivationConverter.actMap[self.act][0]
        self.low = ActivationConverter.actMap[self.act][1]
        self.high = ActivationConverter.actMap[self.act][2]

    @staticmethod
    def IsCompatible(value):
        return value.lower() in ActivationConverter.actMap

    def SetToDefaultName(self):
        self.name = self.act
        return self

    def TransformOperand(self, op, arg=None):
        """Set INT32 operands to the activation enum; clamp Output values to
        the activation's [low, high] range expressed in the output's type."""
        if op.type.type == "INT32":  # activation enum
            return op.SetValue(self.enum)
        else:
            assert isinstance(op, Output)
            v = op.GetValueAsNumpy()
            if self.low is not None:
                low = Quantize(self.low, op.type)
                v = np.maximum(v, low)
            if self.high is not None:
                high = Quantize(self.high, op.type)
                v = np.minimum(v, high)
            return op.SetValueFromNumpy(v)

# Convert all constant tensors as model inputs.
class AllTensorsAsInputsConverter(ModelVariation):
    """Variation that turns constant tensor parameters into model inputs."""

    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        """Name the variation "all_tensors_as_inputs" and return self."""
        self.name = "all_tensors_as_inputs"
        return self

    def TransformModel(self, model):
        """Convert every non-scalar Parameter that has a value to an Input.

        Raises:
            SkipVariation: if the model does not have exactly one operation,
                or has no such parameter.
        """
        if len(model.operations) != 1:
            raise SkipVariation

        # Find all constant tensors.
        tensorParams = [
            p for p in model.operands
            if type(p) is Parameter and not p.type.IsScalar() and p.value is not None
        ]
        if not tensorParams:
            raise SkipVariation

        # Convert to model inputs.
        model.UpdateEquivalentOperands([op.ConvertTo(Input) for op in tensorParams])
        return model

def CompatibleWithADD(op):
    """Return True if `op` can take part in the placeholder ADD operation.

    The operand must have at most 4 dimensions, a non-empty value, and one
    of the float or 8-bit asymmetric quantized tensor types listed below.
    """
    return (len(op.type.dimensions) <= 4 and
            len(op.value) > 0 and
            op.type.type in ["TENSOR_FLOAT32", "TENSOR_QUANT8_ASYMM",
                             "TENSOR_FLOAT16", "TENSOR_QUANT8_ASYMM_SIGNED"])

# Add a placeholder ADD operation before each model input to make it as an internal operand.
class AllInputsAsInternalCoverter(ModelVariation):
    """Variation that feeds each eligible input through a placeholder ADD."""

    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        """Name the variation "all_inputs_as_internal" and return self."""
        self.name = "all_inputs_as_internal"
        return self

    def TransformModel(self, model):
        """Make each eligible model input the output of a placeholder ADD.

        Raises:
            SkipVariation: if the model does not have exactly one operation,
                or no input is both CompatibleWithADD and `mayBeInternal`.
        """
        if len(model.operations) != 1:
            raise SkipVariation

        # Find all input tensors that can be an output of the ADD operation.
        modelInputs = [i for i in model.GetInputs() if CompatibleWithADD(i) and i.mayBeInternal]
        if not modelInputs:
            raise SkipVariation

        # Make every input an output of a placeholder operation: input_new ADD placeholder = input.
        for op in modelInputs:
            newInput = op.ConvertTo(Input, name=op.name + "_new")
            # The placeholder holds the type's zero point as its only element.
            placeholderParam = Parameter("placeholder",
                                         (op.type.type, [1], op.type.scale, op.type.zeroPoint),
                                         [op.type.zeroPoint])
            model.Operation("ADD", newInput, placeholderParam, 0).To(op)

        # Convert to internal operands.
        model.UpdateEquivalentOperands([op.ConvertTo(Internal) for op in modelInputs])
        return model

# Add a placeholder ADD operation after each model output to make it as an internal operand.
class AllOutputsAsInternalCoverter(ModelVariation):
    """Variation that feeds each eligible output through a placeholder ADD."""

    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        """Name the variation "all_outputs_as_internal" and return self."""
        self.name = "all_outputs_as_internal"
        return self

    def TransformModel(self, model):
        """Make each eligible model output the input of a placeholder ADD.

        Raises:
            SkipVariation: if the model does not have exactly one operation,
                or no output is CompatibleWithADD.
        """
        if len(model.operations) != 1:
            raise SkipVariation

        # Find all output tensors that can be an input to an ADD operation.
        modelOutputs = [o for o in model.GetOutputs() if CompatibleWithADD(o)]
        if not modelOutputs:
            raise SkipVariation

        # Make every output an input of a placeholder operation: output ADD placeholder = output_new.
        for op in modelOutputs:
            newOutput = op.ConvertTo(Output, name=op.name + "_new")
            # The placeholder holds the type's zero point as its only element.
            placeholderParam = Parameter("placeholder",
                                         (op.type.type, [1], op.type.scale, op.type.zeroPoint),
                                         [op.type.zeroPoint])
            model.Operation("ADD", op, placeholderParam, 0).To(newOutput)

        # Convert to internal operands.
        model.UpdateEquivalentOperands([op.ConvertTo(Internal) for op in modelOutputs])
        return model

# An example is always attached to a model, and could have multiple variations
class Example:
    """A set of input/output feeds for one model, plus its variation groups."""

    # Every Example registers itself here on construction.
    examples = []
    # Maps a test name to the version passed to Example.SetVersion.
    versionOverrides = {}

    def __init__(self, *args, model=None, name=None):
        """Create an example and register it in Example.examples.

        Args:
            *args: Feeds. A tuple or list is stored unchanged. A dict is
                split into an (inputs, outputs) pair using the model's
                GetInputs() and GetOutputs() as keys.
            model: Model the example belongs to; defaults to the most
                recently created one (Model.models[-1]).
            name: Optional example name.
        """
        self.model = Model.models[-1] if model is None else model
        self.name = name
        self.expectedMultinomialDistributionTolerance = 0
        self.expectFailure = False
        self.testLifeTimeVariation = True
        self.feedDicts = []
        for feedDict in args:
            if isinstance(feedDict, (tuple, list)):
                self.feedDicts.append(feedDict)
            elif isinstance(feedDict, dict):
                self.feedDicts.append((
                    {i: feedDict[i] for i in self.model.GetInputs()},
                    {o: feedDict[o] for o in self.model.GetOutputs()}
                ))
            else:
                assert False, \
                    "Example feed must be a tuple, list or dict, got %s"%type(feedDict).__name__
        self.variations = []
        Example.examples.append(self)

    @staticmethod
    def SetVersion(ver, *args):
        """Record `ver` as the version override for each test name in args."""
        for name in args:
            Example.versionOverrides[name] = ver

    # Main entrance of test generator
    @staticmethod
    def DumpAllExamples(DumpModel=None, model_fd=None,
                        DumpExample=None, example_fd=None,
                        DumpTest=None, test_fd=None):
        """Combine compatible examples, then dump each remaining one."""
        Example.CombineAllExamples()
        for example in Example.examples:
            example.Dump(DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd)

    # Combine examples with the same model, same name, and same set of variations
    @staticmethod
    def CombineAllExamples():
        """Merge examples sharing a key into the first one, keeping order."""
        modelMap = {}
        newExamples = []
        for example in Example.examples:
            key = (example.model, example.name, tuple(tuple(e) for e in example.variations))
            if key in modelMap:
                modelMap[key].Combine(example)
            else:
                modelMap[key] = example
                newExamples.append(example)
        Example.examples = newExamples

    def AddVariations(self, *args, includeDefault=True, defaultName=None):
        """Append one variation group built from `args` and return self.

        With includeDefault, the group starts with
        DefaultVariation(defaultName). Each argument is passed through
        ImplicitVariation.ImplicitConvertion before being added.
        """
        self.variations.append([DefaultVariation(defaultName)] if includeDefault else [])
        self.variations[-1].extend(ImplicitVariation.ImplicitConvertion(i) for i in args)
        return self

    def AddNchw(self, *args, includeDefault=True, defaultName="nhwc"):
        """Add an "nchw" DataLayoutConverter group identified with `args`."""
        var = DataLayoutConverter("nchw").Identify(args)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelaxed(self, isRelaxed=True, includeDefault=True, defaultName=None):
        """Add a RelaxedModeConverter(isRelaxed) variation group."""
        var = RelaxedModeConverter(isRelaxed)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelu(self, *args, includeDefault=True, defaultName=None):
        """Add a "relu" ActivationConverter group identified with `args`."""
        var = ActivationConverter("relu").Identify(args)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllActivations(self, *args):
        """Add one group holding every activation in actMap, sorted by name."""
        var = [ActivationConverter(i).Identify(args)
               for i in sorted(ActivationConverter.actMap.keys())]
        self.AddVariations(*var, includeDefault=False)
        return self

    def GuessOriginalAxisAndDim(self, *args):
        """Derive (origin axis, rank) from the given operands.

        An INT32 operand supplies the axis through value[0]. All other
        operands must share one rank. The axis defaults to the last one and
        is returned as a non-negative index.
        """
        origin = None
        dim = None
        for arg in args:
            if arg.type.type == "INT32":
                origin = arg.value[0]
            else:
                if dim is None:
                    dim = len(arg.type.dimensions)
                else:
                    assert dim == len(arg.type.dimensions)
        assert dim is not None
        origin = dim - 1 if origin is None else origin
        origin = origin + dim if origin < 0 else origin
        return origin, dim

    def AddAxis(self, axis, *args, includeDefault=True, defaultName=None):
        """Add a group moving the original axis to each axis in `axis`.

        `axis` may be a single int or an iterable of ints.
        """
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        axis = [axis] if type(axis) is int else list(axis)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in axis]
        self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllPositiveAxis(self, *args):
        """Add a group covering every target axis in range(dim)."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in range(dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllAxis(self, *args):
        """Add a group covering every target axis in range(-dim, dim)."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in range(-dim, dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddDims(self, dims, *args, includeDefault=True, defaultName=None):
        """Add a group that reduces the rank to each value in `dims`.

        For a requested rank i, the first (dim - i) axes other than the
        original axis are dropped. `dims` may be an int or an iterable.
        """
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        dims = [dims] if type(dims) is int else list(dims)
        # Candidate axes to drop: every axis except the original one.
        drop = list(range(dim))
        drop.pop(origin)
        var = [AxisConverter(origin, origin, dim, drop[0:(dim-i)]).Identify(args) for i in dims]
        self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllDims(self, *args):
        """Add a group dropping 0 to dim - 1 axes other than the original."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        drop = list(range(dim))
        drop.pop(origin)
        var = [AxisConverter(origin, origin, dim, drop[0:i]).Identify(args) for i in range(dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllDimsAndPositiveAxis(self, *args):
        """Add a group dropping the first i axes, for each target j >= i."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, j, dim, range(i)).Identify(args) \
               for i in range(dim) for j in range(i, dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllDimsAndAxis(self, *args):
        """Like AddAllDimsAndPositiveAxis, plus each target as j - dim."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, k, dim, range(i)).Identify(args) \
               for i in range(dim) for j in range(i, dim) for k in [j, j - dim]]
        self.AddVariations(*var, includeDefault=False)
        return self

    def Combine(self, other):
        """Append `other`'s feeds to this example and return self.

        Both examples must share the model, the variations and the name.
        """
        assert self.model is other.model, "Only examples targetting the same model can be combined"
        assert tuple(self.variations) == tuple(other.variations), \
            "Only examples with the same set of variations can be combined"
        assert self.name == other.name, "Only examples with the same name can be combined"
        self.feedDicts.extend(other.feedDicts)
        return self

    def Dump(self, DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd):
        """Compile and dump this example for every feed and variation combo.

        For each feed and each element of the cartesian product of the
        variation groups, the variations are applied to a deep copy of the
        model, the copy is compiled, and DumpExample / DumpTest are called
        when both the callback and its file object are given. Combinations
        that raise SkipVariation are skipped. DumpModel and model_fd are
        accepted but not used here.
        """
        if self.testLifeTimeVariation and len(self.model.operations) == 1 and \
                self.expectedMultinomialDistributionTolerance == 0:
            self.AddVariations(AllTensorsAsInputsConverter())
            self.AddVariations(AllInputsAsInternalCoverter())
        # Give every still-unnamed variation its default name.
        for variationGroup in self.variations:
            for variation in variationGroup:
                if variation.name is None:
                    variation.SetToDefaultName()

        for feedDict in self.feedDicts:
            self.model.Feed(feedDict)
            for variationList in itertools.product(*self.variations):
                modelOrigin = self.model
                self.model = copy.deepcopy(self.model)

                # Apply variations
                try:
                    for variation in variationList:
                        self.model = variation.ApplyTo(self.model)
                except SkipVariation:
                    self.model = modelOrigin
                    continue

                # Concat names for test and examples
                varNames = [v.name for v in variationList]
                self.testName = NamedTest(FileNames.specName, self.model.name, self.name, *varNames)
                self.examplesName = GlobalVariable("test_model", self.model.name, self.name,
                                                   *varNames)
                if str(self.testName) in Example.versionOverrides:
                    self.model.IntroducedIn(Example.versionOverrides[str(self.testName)])
                self.model.Compile()

                # Dump files
                if DumpExample is not None and example_fd is not None:
                    DumpExample(self, example_fd)
                if DumpTest is not None and test_fd is not None:
                    DumpTest(self, test_fd)

                # Restore model before variation
                self.model = modelOrigin
        return self

    # Specifies the RANDOM_MULTINOMIAL distribution tolerance.
    # If set to greater than zero, the input is compared as log-probabilities
    # to the output and must be within this tolerance to pass.
    def WithMultinomialDistributionTolerance(self, expectedTolerance):
        assert self.expectFailure is False
        self.expectedMultinomialDistributionTolerance = expectedTolerance
        return self

    # Specifies that this example is expected to fail during compilation or execution.
    def ExpectFailure(self):
        assert self.expectedMultinomialDistributionTolerance == 0
        self.expectFailure = True
        return self

    def DisableLifeTimeVariation(self):
        """Stop Dump from adding the lifetime variations; return self."""
        self.testLifeTimeVariation = False
        return self

class FileNames:
    """Global bookkeeping of the spec files and their example targets."""

    specFiles = []
    specNames = []
    exampleFiles = []
    # State of the file selected by the latest NextFile() call.
    specFile = ""
    specName = ""
    exampleFile = ""
    version = ""
    fileIndex = 0

    @staticmethod
    def InitializeFileLists(spec, example):
        """Populate the spec file, spec name and example file lists.

        Args:
            spec: A spec file, or a directory scanned for "*.mod.py" files.
            example: Example target, interpreted by ParseTargetFiles.
        """
        # get all spec files and target files
        if os.path.isfile(spec):
            FileNames.specFiles = [os.path.abspath(spec)]
        elif os.path.isdir(spec):
            FileNames.specFiles = sorted([os.path.abspath(os.path.join(spec, f))
                                          for f in os.listdir(spec) if f.endswith(".mod.py")])
        else:
            assert False, "%s is neither a file or a directory"%spec
        # The spec name is the base name up to its first dot.
        FileNames.specNames = [re.sub(r"\..*", "", os.path.basename(f))
                               for f in FileNames.specFiles]
        FileNames.exampleFiles = FileNames.ParseTargetFiles(example, ".example.cpp")

    @staticmethod
    def ParseTargetFiles(arg, ext):
        """Return one target per spec file.

        None yields None for every spec; a directory yields
        "<dir>/<specName><ext>"; "-" is kept as "-"; any other value is used
        as a single absolute path shared by all specs.
        """
        numFiles = len(FileNames.specFiles)
        if arg is None:
            return [None] * numFiles
        absPath = os.path.abspath(arg)
        if os.path.isdir(arg):
            target = [os.path.join(absPath, f + ext) for f in FileNames.specNames]
        elif arg == "-":
            target = ["-"] * numFiles
        else:
            target = [absPath] * numFiles
        return target

    @staticmethod
    def NextFile():
        """Advance to the next spec file and reset the per-file global state.

        Returns:
            False when all spec files have been consumed, True otherwise.
        """
        if FileNames.fileIndex >= len(FileNames.specFiles):
            return False
        FileNames.specFile = FileNames.specFiles[FileNames.fileIndex]
        FileNames.specName = FileNames.specNames[FileNames.fileIndex]
        FileNames.exampleFile = FileNames.exampleFiles[FileNames.fileIndex]
        FileNames.fileIndex += 1
        # Start the new spec with empty registries.
        NamedObject.existingNames = set()
        NamedVariable.existingNames = set()
        NamedTest.existingNames = set()
        Type.typesMap = dict()
        Model.models = list()
        Example.examples = list()
        Configuration.use_shm_for_weights = False

        # Extract version from absolute file path.
        versionMatch = re.findall(r"/V\d_\d/", FileNames.specFile)
        if len(versionMatch) == 1:
            FileNames.version = versionMatch[0].strip('/')
        else:
            FileNames.version = None
        return True

class Configuration:
    """Global generator flags."""

    use_shm_for_weights = False
    hook_mode = False

    @staticmethod
    def useSHM():
        """Return the current use_shm_for_weights flag."""
        return Configuration.use_shm_for_weights

def GetTestGeneratorMTime():
    """Return the newest modification time of the generator scripts.

    The scripts are looked up in the directory containing this file.
    """
    tgFiles = ['test_generator.py', 'example_generator.py']
    tgDir = os.path.dirname(__file__)
    return max(os.path.getmtime(os.path.join(tgDir, filename))
               for filename in tgFiles)

def MightNeedRegeneration():
    """Return True if the current example file may be out of date.

    That is the case when no example target is configured, when the target
    does not exist, or when it is not newer than both the spec file and the
    generator scripts.
    """
    # ParseTargetFiles yields None targets when no example path was given.
    # os.path.exists(None) raises TypeError, so handle that case up front and
    # let the caller process the spec without writing anything.
    if FileNames.exampleFile is None:
        return True
    specTime = os.path.getmtime(FileNames.specFile)
    tgTime = GetTestGeneratorMTime()
    return not os.path.exists(FileNames.exampleFile) or \
           os.path.getmtime(FileNames.exampleFile) <= max(specTime, tgTime)

def Read(filename):
    """Return the whole text content of `filename`."""
    with open(filename) as reader:
        return reader.read()

def AtomicWrite(filename, data):
    """Write `data` to `filename` via a temporary file and os.replace.

    The temporary file is removed if writing or replacing fails.
    """
    # os.replace(src, dest) may fail if src and dest are on different
    # filesystems, so the temporary file is created next to the destination.
    tempFile = filename + '.tmp'
    try:
        with open(tempFile, 'w') as writer:
            writer.write(data)
        os.replace(tempFile, filename)
        # Nothing left to clean up once the replace succeeded.
        tempFile = None
    finally:
        if tempFile is not None and os.path.exists(tempFile):
            os.remove(tempFile)

def GetExecScope():
    """Return the globals dictionary that spec files are executed with."""
    return dict(
        ActivationConverter=ActivationConverter,
        AllInputsAsInternalCoverter=AllInputsAsInternalCoverter,
        AllOutputsAsInternalCoverter=AllOutputsAsInternalCoverter,
        AllTensorsAsInputsConverter=AllTensorsAsInputsConverter,
        BoolScalar=BoolScalar,
        Configuration=Configuration,
        DataLayoutConverter=DataLayoutConverter,
        DataTypeConverter=DataTypeConverter,
        Example=Example,
        Float16Scalar=Float16Scalar,
        Float32Scalar=Float32Scalar,
        Float32Vector=Float32Vector,
        IgnoredOutput=IgnoredOutput,
        Input=Input,
        Int32Scalar=Int32Scalar,
        Int32Vector=Int32Vector,
        Internal=Internal,
        Model=Model,
        Operand=Operand,
        Output=Output,
        Parameter=Parameter,
        RelaxedModeConverter=RelaxedModeConverter,
        SubgraphReference=SubgraphReference,
        SymmPerChannelQuantParams=SymmPerChannelQuantParams)

def ArgumentParser():
    """Return an argparse parser with the `spec` and `--hook` arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file or directory")
    parser.add_argument("--hook", help="hook mode", action='store_true')
    return parser

def ParseArgs(parser):
    """Parse the command line, store the hook flag, and return the args."""
    args = parser.parse_args()
    Configuration.hook_mode = args.hook
    return args

def Run(InitializeFiles=None, DumpExample=None):
    """Process every spec file and write (or, in hook mode, check) examples.

    Args:
        InitializeFiles: Optional callable invoked as
            InitializeFiles(example_fd=<buffer or None>) before dumping.
        DumpExample: Optional callable forwarded to Example.DumpAllExamples.

    In hook mode a missing or differing example file prints a message and
    exits with status 1. Any exception raised while processing a spec prints
    a traceback and exits.
    """
    # One scope is created up front and reused for every spec file.
    exec_scope = GetExecScope()
    while FileNames.NextFile():
        try:
            if not MightNeedRegeneration():
                continue
            # The spec file is executed as Python code; it must be trusted.
            exec(Read(FileNames.specFile), exec_scope)
            example_buf = io.StringIO() if FileNames.exampleFile else None
            # InitializeFiles defaults to None, so only call it when given.
            if InitializeFiles is not None:
                InitializeFiles(example_fd=example_buf)
            Example.DumpAllExamples(DumpExample=DumpExample, example_fd=example_buf)
            if FileNames.exampleFile is None:
                continue
            if Configuration.hook_mode and (not os.path.exists(FileNames.exampleFile) or
                                            Read(FileNames.exampleFile) != example_buf.getvalue()):
                print(('\n{filename} is out of date. '
                       'Please run {generate_all_tests_sh} before uploading.\n').format(
                               filename=FileNames.exampleFile,
                               generate_all_tests_sh=os.path.abspath(os.path.join(
                                       os.path.dirname(__file__), '..', '..', 'runtime', 'test',
                                       'specs', 'generate_all_tests.sh'))))
                sys.exit(1)
            AtomicWrite(FileNames.exampleFile, example_buf.getvalue())
        except Exception:
            traceback.print_exc()
            sys.exit("Exception raised when processing {}".format(FileNames.specFile))