#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Script for updating SPIR-V dialect by scraping information from SPIR-V
# HTML and JSON specs from the Internet.
#
# For example, to define the enum attribute for SPIR-V memory model:
#
# ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \
#                        --new-enum MemoryModel
#
# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
# SPIR-V enum classes.

# NOTE(review): this file was recovered from a copy in which line breaks and
# runs of whitespace had been collapsed. The code layout below was rebuilt
# from syntax, but the spaces *inside* string literals (TableGen indentation,
# textwrap indents, search prefixes) are reproduced exactly as they were
# found. Confirm them against an existing generated .td file before relying
# on the emitted formatting.

import itertools
import json
import math
import re
import textwrap

SPIRV_HTML_SPEC_URL = (
    "https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html"
)
SPIRV_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json"

SPIRV_CL_EXT_HTML_SPEC_URL = "https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html"
SPIRV_CL_EXT_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json"

AUTOGEN_OP_DEF_SEPARATOR = "\n// -----\n\n"
AUTOGEN_ENUM_SECTION_MARKER = "enum section. Generated from SPIR-V spec; DO NOT MODIFY!"
AUTOGEN_OPCODE_SECTION_MARKER = (
    "opcode section. Generated from SPIR-V spec; DO NOT MODIFY!"
)

# Upper bound, in seconds, for connecting to / reading from the spec servers.
HTTP_TIMEOUT_SECONDS = 60


def _fetch(url):
    """Downloads `url` and returns the raw response body.

    Arguments:
      - url: the URL to download.

    Returns:
      - The response body as bytes.

    Raises:
      - requests.HTTPError if the server answers with an error status, so that
        an error page is never handed to the JSON/HTML parsers.
    """
    # Imported lazily (like bs4 below) so the pure helpers in this file remain
    # importable on machines without the networking dependency installed.
    import requests

    response = requests.get(url, timeout=HTTP_TIMEOUT_SECONDS)
    response.raise_for_status()
    return response.content


def get_spirv_doc_from_html_spec(url, settings):
    """Extracts instruction documentation from SPIR-V HTML spec.

    Arguments:
      - url: the HTML spec to scrape; None selects SPIRV_HTML_SPEC_URL.
      - settings: an object whose `gen_cl_ops` attribute selects the page
        layout to parse (OpenCL extended instruction set vs. core spec).

    Returns:
      - A dict mapping from instruction opcode to documentation.
    """
    if url is None:
        url = SPIRV_HTML_SPEC_URL

    spec = _fetch(url)

    from bs4 import BeautifulSoup

    spirv = BeautifulSoup(spec, "html.parser")

    doc = {}

    if settings.gen_cl_ops:
        section_anchor = spirv.find("h2", {"id": "_binary_form"})
        for section in section_anchor.parent.find_all("div", {"class": "sect2"}):
            for table in section.find_all("table"):
                inst_html = table.tbody.tr.td
                opname = inst_html.a["id"]
                # Ignore the first line, which is just the opname.
                doc[opname] = inst_html.text.split("\n", 1)[1].strip()
    else:
        section_anchor = spirv.find("h3", {"id": "_instructions_3"})
        for section in section_anchor.parent.find_all("div", {"class": "sect3"}):
            for table in section.find_all("table"):
                inst_html = table.tbody.tr.td.p
                opname = inst_html.a["id"]
                # Ignore the first line, which is just the opname.
                doc[opname] = inst_html.text.split("\n", 1)[1].strip()

    return doc


def get_spirv_grammar_from_json_spec(url):
    """Extracts operand kind and instruction grammar from SPIR-V JSON spec.

    The operand kinds always come from the core grammar. When `url` is given,
    the instructions are taken from that (extended instruction set) grammar
    instead of the core one.

    Arguments:
      - url: URL of an extended instruction set grammar, or None.

    Returns:
      - A list containing all operand kinds' grammar
      - A list containing all instructions' grammar
    """
    spirv = json.loads(_fetch(SPIRV_JSON_SPEC_URL))

    if url is None:
        return spirv["operand_kinds"], spirv["instructions"]

    spirv_ext = json.loads(_fetch(url))

    return spirv["operand_kinds"], spirv_ext["instructions"]


def split_list_into_sublists(items):
    """Split the list of items into multiple sublists.

    This is to make sure the string composed from each sublist won't exceed
    80 characters. Each item is budgeted its own length plus two characters
    for the ", " separator.

    Arguments:
      - items: a list of strings

    Returns:
      - A list of non-empty sublists that together contain every item in the
        original order. An item that is too long on its own still gets a
        sublist to itself.
    """
    chunks = []
    chunk = []
    chunk_len = 0

    for item in items:
        item_len = len(item) + 2
        # Only start a new sublist when the current one has content; otherwise
        # an over-long first item would produce an empty leading sublist.
        if chunk and chunk_len + item_len > 80:
            chunks.append(chunk)
            chunk = []
            chunk_len = 0
        chunk.append(item)
        chunk_len += item_len

    if chunk:
        chunks.append(chunk)

    return chunks


def toposort(dag, sort_fn):
    """Topologically sorts the given dag.

    Arguments:
      - dag: a dict mapping from a node to its incoming nodes.
      - sort_fn: a function for sorting nodes in the same batch.

    Returns:
      A list containing topologically sorted nodes.

    Raises:
      - AssertionError if some nodes can never be scheduled (a cycle, or an
        incoming node that is not itself a key of `dag`).
    """

    # Returns the next batch of nodes without incoming edges
    def get_next_batch(dag):
        while True:
            no_prev_nodes = set(node for node, prev in dag.items() if not prev)
            if not no_prev_nodes:
                break
            yield sorted(no_prev_nodes, key=sort_fn)
            dag = {
                node: (prev - no_prev_nodes)
                for node, prev in dag.items()
                if node not in no_prev_nodes
            }
        assert not dag, "found cyclic dependency"

    sorted_nodes = []
    for batch in get_next_batch(dag):
        sorted_nodes.extend(batch)

    return sorted_nodes


def toposort_capabilities(all_cases):
    """Returns topologically sorted capability (symbol, value) pairs.

    Arguments:
      - all_cases: all capability cases (containing symbol, value, and implied
        capabilities).

    Returns:
      A list containing topologically sorted capability (symbol, value) pairs.
    """
    dag = {}
    name_to_value = {}
    for case in all_cases:
        # Get the current capability.
        cur = case["enumerant"]
        name_to_value[cur] = case["value"]

        # Get capabilities implied by the current capability.
        dag[cur] = set(case.get("capabilities", []))

    sorted_caps = toposort(dag, lambda x: name_to_value[x])
    # Attach the capability's value as the second component of the pair.
    return [(c, name_to_value[c]) for c in sorted_caps]


def get_availability_spec(enum_case, for_op, for_cap):
    """Returns the availability specification string for the given enum case.

    Arguments:
      - enum_case: the enum case to generate availability spec for. It may
        contain 'version', 'lastVersion', 'extensions', or 'capabilities'.
      - for_op: bool value indicating whether this is the availability spec
        for an op itself.
      - for_cap: bool value indicating whether this is the availability spec
        for capabilities themselves.

    Returns:
      - A `let availability = [...];` string if with availability spec or
        empty string if without availability spec
    """
    assert not (for_op and for_cap), "cannot set both for_op and for_cap"

    DEFAULT_MIN_VERSION = "MinVersion<SPIRV_V_1_0>"
    DEFAULT_MAX_VERSION = "MaxVersion<SPIRV_V_1_6>"
    DEFAULT_CAP = "Capability<[]>"
    DEFAULT_EXT = "Extension<[]>"

    min_version = enum_case.get("version", "")
    if min_version == "None":
        min_version = ""
    elif min_version:
        min_version = "MinVersion<SPIRV_V_{}>".format(min_version.replace(".", "_"))
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not min_version:
        min_version = DEFAULT_MIN_VERSION

    max_version = enum_case.get("lastVersion", "")
    if max_version:
        max_version = "MaxVersion<SPIRV_V_{}>".format(max_version.replace(".", "_"))
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not max_version:
        max_version = DEFAULT_MAX_VERSION

    exts = enum_case.get("extensions", [])
    if exts:
        exts = "Extension<[{}]>".format(", ".join(sorted(set(exts))))
        # We need to strip the minimal version requirement if this symbol is
        # available via an extension, which means *any* SPIR-V version can
        # support it as long as the extension is provided. The grammar's
        # 'version' field under such case should be interpreted as this symbol
        # is introduced as a core symbol since the given version, rather than
        # a minimal version requirement.
        min_version = DEFAULT_MIN_VERSION if for_op else ""
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not exts:
        exts = DEFAULT_EXT

    caps = enum_case.get("capabilities", [])
    implies = ""
    if caps:
        prefixed_caps = ["SPIRV_C_{}".format(c) for c in sorted(set(caps))]
        if for_cap:
            # If this is generating the availability for capabilities, we need
            # to put the capability "requirements" in implies field because
            # now the "capabilities" field in the source grammar means so.
            caps = ""
            implies = "list<I32EnumAttrCase> implies = [{}];".format(
                ", ".join(prefixed_caps)
            )
        else:
            caps = "Capability<[{}]>".format(", ".join(prefixed_caps))
            implies = ""
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not caps:
        caps = DEFAULT_CAP

    avail = ""
    # Compose availability spec if any of the requirements is not empty.
    # For ops, because we have a default in SPIRV_Op class, omit if the spec
    # is the same.
    if (min_version or max_version or caps or exts) and not (
        for_op
        and min_version == DEFAULT_MIN_VERSION
        and max_version == DEFAULT_MAX_VERSION
        and caps == DEFAULT_CAP
        and exts == DEFAULT_EXT
    ):
        joined_spec = ",\n ".join(
            [e for e in [min_version, max_version, exts, caps] if e]
        )
        avail = "{} availability = [\n {}\n ];".format(
            "let" if for_op else "list<Availability>", joined_spec
        )

    return "{}{}{}".format(implies, "\n " if implies and avail else "", avail)


def gen_operand_kind_enum_attr(operand_kind):
    """Generates the TableGen EnumAttr definition for the given operand kind.

    Arguments:
      - operand_kind: one entry of the grammar's 'operand_kinds' list.

    Returns:
      - The operand kind's name
      - A string containing the TableGen EnumAttr definition

      Both are empty strings when the operand kind has no 'enumerants'.
    """
    if "enumerants" not in operand_kind:
        return "", ""

    # Returns a symbol for the given case in the given kind. This function
    # handles Dim specially to avoid having numbers as the start of symbols,
    # which does not play well with C++ and the MLIR parser.
    def get_case_symbol(kind_name, case_name):
        if kind_name == "Dim" and case_name in ("1D", "2D", "3D"):
            return "Dim{}".format(case_name)
        return case_name

    kind_name = operand_kind["kind"]
    is_bit_enum = operand_kind["category"] == "BitEnum"
    kind_acronym = "".join([c for c in kind_name if c >= "A" and c <= "Z"])

    name_to_case_dict = {
        case["enumerant"]: case for case in operand_kind["enumerants"]
    }

    if kind_name == "Capability":
        # Special treatment for capability cases: we need to sort them
        # topologically because a capability can refer to another via the
        # 'implies' field.
        kind_cases = toposort_capabilities(operand_kind["enumerants"])
    else:
        kind_cases = [
            (case["enumerant"], case["value"]) for case in operand_kind["enumerants"]
        ]
    max_len = max([len(symbol) for (symbol, _) in kind_cases])

    # Generate the definition for each enum case
    case_category = "I32Bit" if is_bit_enum else "I32"
    fmt_str = (
        "def SPIRV_{acronym}_{case_name} {colon:>{offset}} "
        '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}'
    )
    case_defs = []
    for name, raw_value in kind_cases:
        # Bit enum values are hexadecimal strings in the grammar.
        if is_bit_enum:
            value = int(raw_value, base=16)
        else:
            value = int(raw_value)
        avail = get_availability_spec(
            name_to_case_dict[name],
            False,
            kind_name == "Capability",
        )
        if is_bit_enum:
            if value == 0:
                suffix = "None"
                value = ""
            else:
                # Bit cases are declared by bit position, not by mask value.
                suffix = "Bit"
                value = ", {}".format(int(math.log2(value)))
        else:
            suffix = ""
            value = ", {}".format(value)

        case_def = fmt_str.format(
            category=case_category,
            suffix=suffix,
            acronym=kind_acronym,
            case_name=name,
            symbol=get_case_symbol(kind_name, name),
            case_value_part=value,
            avail=" {{\n {}\n}}".format(avail) if avail else ";",
            colon=":",
            # Pads so that the colons of all cases line up in one column.
            offset=(max_len + 1 - len(name)),
        )
        case_defs.append(case_def)
    case_defs = "\n".join(case_defs)

    # Generate the list of enum case names
    fmt_str = "SPIRV_{acronym}_{symbol}"
    case_names = [
        fmt_str.format(acronym=kind_acronym, symbol=case[0]) for case in kind_cases
    ]

    # Split them into sublists and concatenate into multiple lines
    case_names = split_list_into_sublists(case_names)
    case_names = ["{:6}".format("") + ", ".join(sublist) for sublist in case_names]
    case_names = ",\n".join(case_names)

    # Generate the enum attribute definition
    kind_category = "Bit" if is_bit_enum else "I32"
    enum_attr = (
        "def SPIRV_{name}Attr :\n"
        ' SPIRV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", '
        '"{snake_name}", [\n'
        "{cases}\n"
        " ]>;"
    ).format(
        name=kind_name,
        snake_name=snake_casify(kind_name),
        category=kind_category,
        cases=case_names,
    )
    return kind_name, case_defs + "\n\n" + enum_attr


def gen_opcode(instructions):
    """Generates the TableGen definition to map opname to opcode

    Arguments:
      - instructions: a non-empty list of instruction grammar entries, each
        with an 'opname' and an 'opcode'.

    Returns:
      - A string containing the TableGen SPIRV_OpCode definition
    """

    max_len = max([len(inst["opname"]) for inst in instructions])
    def_fmt_str = (
        "def SPIRV_OC_{name} {colon:>{offset}} " 'I32EnumAttrCase<"{name}", {value}>;'
    )
    opcode_defs = [
        def_fmt_str.format(
            name=inst["opname"],
            value=inst["opcode"],
            colon=":",
            # Pads so that the colons of all cases line up in one column.
            offset=(max_len + 1 - len(inst["opname"])),
        )
        for inst in instructions
    ]
    opcode_str = "\n".join(opcode_defs)

    decl_fmt_str = "SPIRV_OC_{name}"
    opcode_list = [decl_fmt_str.format(name=inst["opname"]) for inst in instructions]
    opcode_list = split_list_into_sublists(opcode_list)
    opcode_list = ["{:6}".format("") + ", ".join(sublist) for sublist in opcode_list]
    opcode_list = ",\n".join(opcode_list)
    enum_attr = (
        "def SPIRV_OpcodeAttr :\n"
        ' SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '
        '"opcode", [\n'
        "{lst}\n"
        " ]>;".format(name="Opcode", lst=opcode_list)
    )
    return opcode_str + "\n\n" + enum_attr


def map_cap_to_opnames(instructions):
    """Maps capabilities to instructions enabled by those capabilities

    Instructions that list no capability are filed under the placeholder key
    "0_core_0".

    Arguments:
      - instructions: a list containing a subset of SPIR-V instructions'
        grammar
    Returns:
      - A map with keys representing capabilities and values of lists of
        instructions enabled by the corresponding key
    """
    cap_to_inst = {}

    for inst in instructions:
        for cap in inst.get("capabilities", ["0_core_0"]):
            cap_to_inst.setdefault(cap, []).append(inst["opname"])

    return cap_to_inst


def gen_instr_coverage_report(path, instructions):
    """Dumps to standard output a YAML report of current instruction coverage

    Only capabilities that still have at least one unsupported instruction
    appear in the report.

    Arguments:
      - path: the path to SPIRVBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
    """
    # Imported lazily: only this report needs PyYAML.
    import yaml

    with open(path, "r") as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    # Same sanity check as update_td_opcodes: begin and end markers present.
    assert len(content) == 3

    prefix = "def SPIRV_OC_"
    existing_opcodes = {
        k[len(prefix) :] for k in re.findall(prefix + r"\w+", content[1])
    }

    existing_instructions = [
        inst for inst in instructions if inst["opname"] in existing_opcodes
    ]
    remaining_instructions = [
        inst for inst in instructions if inst["opname"] not in existing_opcodes
    ]

    rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
    ex_cap_to_instr = map_cap_to_opnames(existing_instructions)

    report = {}

    # Merge supported/unsupported lists and the coverage ratio into one report
    for cap, unsupported in rem_cap_to_instr.items():
        supported = ex_cap_to_instr.get(cap, [])
        coverage = len(supported) / (len(supported) + len(unsupported))
        report[cap] = {}
        report[cap]["Supported Instructions"] = supported
        report[cap]["Unsupported Instructions"] = unsupported
        report[cap]["Coverage"] = "{}%".format(int(coverage * 100))

    print(yaml.dump(report))


def update_td_opcodes(path, instructions, filter_list):
    """Updates SPIRVBase.td with new generated opcode cases.

    The opcodes already present in the file are kept; `filter_list` is not
    modified.

    Arguments:
      - path: the path to SPIRVBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
      - filter_list: a list containing new opnames to add
    """

    with open(path, "r") as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    assert len(content) == 3

    # Extend opcode list with existing list
    prefix = "def SPIRV_OC_"
    existing_opcodes = [
        k[len(prefix) :] for k in re.findall(prefix + r"\w+", content[1])
    ]
    wanted_opnames = set(filter_list) | set(existing_opcodes)

    # Generate the opcode for all instructions in SPIR-V
    filter_instrs = [
        inst for inst in instructions if inst["opname"] in wanted_opnames
    ]
    # Sort instruction based on opcode
    filter_instrs.sort(key=lambda inst: inst["opcode"])
    opcode = gen_opcode(filter_instrs)

    # Substitute the opcode
    content = (
        content[0]
        + AUTOGEN_OPCODE_SECTION_MARKER
        + "\n\n"
        + opcode
        + "\n\n// End "
        + AUTOGEN_OPCODE_SECTION_MARKER
        + content[2]
    )

    with open(path, "w") as f:
        f.write(content)


def update_td_enum_attrs(path, operand_kinds, filter_list):
    """Updates SPIRVBase.td with new generated enum definitions.

    The enums already present in the file are regenerated as well;
    `filter_list` is not modified.

    Arguments:
      - path: the path to SPIRVBase.td
      - operand_kinds: a list containing all operand kinds' grammar
      - filter_list: a list containing new enums to add
    """
    with open(path, "r") as f:
        content = f.read()

    content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
    assert len(content) == 3

    # Extend filter list with existing enum definitions
    prefix = "def SPIRV_"
    suffix = "Attr"
    existing_kinds = [
        k[len(prefix) : -len(suffix)]
        for k in re.findall(prefix + r"\w+" + suffix, content[1])
    ]
    wanted_kinds = set(filter_list) | set(existing_kinds)

    # Generate definitions for all enums in filter list
    defs = [
        gen_operand_kind_enum_attr(kind)
        for kind in operand_kinds
        if kind["kind"] in wanted_kinds
    ]
    # Sort alphabetically according to enum name
    defs.sort(key=lambda enum: enum[0])
    # Only keep the definitions from now on
    # Put Capability's definition at the very beginning because capability
    # cases will be referenced later
    defs = [enum[1] for enum in defs if enum[0] == "Capability"] + [
        enum[1] for enum in defs if enum[0] != "Capability"
    ]

    # Substitute the old section
    content = (
        content[0]
        + AUTOGEN_ENUM_SECTION_MARKER
        + "\n\n"
        + "\n\n".join(defs)
        + "\n\n// End "
        + AUTOGEN_ENUM_SECTION_MARKER
        + content[2]
    )

    with open(path, "w") as f:
        f.write(content)


def snake_casify(name):
    """Turns the given name to follow snake_case convention."""
    # Insert "_" before every uppercase letter except one at the very start.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


def map_spec_operand_to_ods_argument(operand):
    """Maps an operand in SPIR-V JSON spec to an op argument in ODS.

    Arguments:
      - A dict containing the operand's kind, quantifier, and name

    Returns:
      - A string containing both the type and name for the argument

    Raises:
      - AssertionError for operand kinds that are not supported here.
    """
    kind = operand["kind"]
    quantifier = operand.get("quantifier", "")

    # These instruction "operands" are for encoding the results; they should
    # not be handled here.
    assert kind != "IdResultType", 'unexpected to handle "IdResultType" kind'
    assert kind != "IdResult", 'unexpected to handle "IdResult" kind'

    if kind == "IdRef":
        if quantifier == "":
            arg_type = "SPIRV_Type"
        elif quantifier == "?":
            arg_type = "Optional<SPIRV_Type>"
        else:
            arg_type = "Variadic<SPIRV_Type>"
    elif kind == "IdMemorySemantics" or kind == "IdScope":
        # TODO: Need to further constrain 'IdMemorySemantics'
        # and 'IdScope' given that they should be generated from OpConstant.
        assert quantifier == "", (
            "unexpected to have optional/variadic memory " "semantics or scope <id>"
        )
        # Drop the leading "Id" to get the enum attribute's name.
        arg_type = "SPIRV_" + kind[2:] + "Attr"
    elif kind == "LiteralInteger":
        if quantifier == "":
            arg_type = "I32Attr"
        elif quantifier == "?":
            arg_type = "OptionalAttr<I32Attr>"
        else:
            arg_type = "OptionalAttr<I32ArrayAttr>"
    elif kind in (
        "LiteralString",
        "LiteralContextDependentNumber",
        "LiteralExtInstInteger",
        "LiteralSpecConstantOpInteger",
        "PairLiteralIntegerIdRef",
        "PairIdRefLiteralInteger",
        "PairIdRefIdRef",
    ):
        # Raised explicitly (rather than `assert False`) so the failure is
        # still reported when Python runs with assertions disabled.
        raise AssertionError('"{}" kind unimplemented'.format(kind))
    else:
        # The rest are all enum operands that we represent with op attributes.
        assert quantifier != "*", "unexpected to have variadic enum attribute"
        arg_type = "SPIRV_{}Attr".format(kind)
        if quantifier == "?":
            arg_type = "OptionalAttr<{}>".format(arg_type)

    name = operand.get("name", "")
    name = snake_casify(name) if name else kind.lower()

    return "{}:${}".format(arg_type, name)


def get_description(text, appendix):
    """Generates the description for the given SPIR-V instruction.

    Arguments:
      - text: Textual description of the operation as string.
      - appendix: Additional contents to attach in description as string,
        including IR examples, and others.

    Returns:
      - A string that corresponds to the description of the Tablegen op.
    """
    fmt_str = "{text}\n\n <!-- End of AutoGen section -->\n{appendix}\n "
    return fmt_str.format(text=text, appendix=appendix)


def get_op_definition(
    instruction, opname, doc, existing_info, settings
):
    """Generates the TableGen op definition for the given SPIR-V instruction.

    Arguments:
      - instruction: the instruction's SPIR-V JSON grammar
      - opname: the op's name as used in the .td file, prefixed with 'Op'
      - doc: the instruction's SPIR-V HTML doc
      - existing_info: a dict containing potential manually specified sections
        for this instruction
      - settings: an object whose `gen_cl_ops` attribute selects the OpenCL
        extended instruction set template

    Returns:
      - A string containing the TableGen op definition
    """
    if settings.gen_cl_ops:
        fmt_str = (
            "def SPIRV_{opname}Op : "
            'SPIRV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
            "{{\n let summary = {summary};\n\n let description = "
            "[{{\n{description}}}];{availability}\n"
        )
    else:
        fmt_str = (
            "def SPIRV_{vendor_name}{opname_src}Op : "
            'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
            "{{\n let summary = {summary};\n\n let description = "
            "[{{\n{description}}}];{availability}\n"
        )

    vendor_name = ""
    inst_category = existing_info.get("inst_category", "Op")
    if inst_category == "Op":
        fmt_str += (
            "\n let arguments = (ins{args});\n\n" " let results = (outs{results});\n"
        )
    elif inst_category.endswith("VendorOp"):
        vendor_name = inst_category.split("VendorOp")[0].upper()
        assert len(vendor_name) != 0, "Invalid instruction category"

    fmt_str += "{extras}" "}}\n"

    opname_src = instruction["opname"]
    if opname.startswith("Op"):
        opname_src = opname_src[2:]
    if len(vendor_name) > 0:
        assert opname_src.endswith(
            vendor_name
        ), "op name does not match the instruction category"
        opname_src = opname_src[: -len(vendor_name)]

    category_args = existing_info.get("category_args", "")

    # The first line of the doc is the summary; the remainder is the body.
    if "\n" in doc:
        summary, text = doc.split("\n", 1)
    else:
        summary = doc
        text = ""
    wrapper = textwrap.TextWrapper(
        width=76, initial_indent=" ", subsequent_indent=" "
    )

    # Format summary. If the summary can fit in the same line, we print it out
    # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
    summary = summary.strip()
    if len(summary) + len(' let summary = "";') <= 80:
        summary = '"{}"'.format(summary)
    else:
        summary = "[{{\n{}\n }}]".format(wrapper.fill(summary))

    # Wrap text
    text = text.split("\n")
    text = [wrapper.fill(line) for line in text if line]
    text = "\n\n".join(text)

    operands = instruction.get("operands", [])

    # Op availability
    avail = get_availability_spec(instruction, True, False)
    if avail:
        avail = "\n\n {0}".format(avail)

    # Set op's result
    results = ""
    if len(operands) > 0 and operands[0]["kind"] == "IdResultType":
        results = "\n SPIRV_Type:$result\n "
        operands = operands[1:]
    if "results" in existing_info:
        results = existing_info["results"]

    # Ignore the operand standing for the result <id>
    if len(operands) > 0 and operands[0]["kind"] == "IdResult":
        operands = operands[1:]

    # Set op's arguments
    arguments = existing_info.get("arguments", None)
    if arguments is None:
        arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
        arguments = ",\n ".join(arguments)
        if arguments:
            # Prepend and append whitespace for formatting
            arguments = "\n {}\n ".format(arguments)

    description = existing_info.get("description", None)
    if description is None:
        assembly = (
            "\n ```\n"
            " [TODO]\n"
            " ```\n\n"
            " #### Example:\n\n"
            " ```mlir\n"
            " [TODO]\n"
            " ```"
        )
        description = get_description(text, assembly)

    return fmt_str.format(
        opname=opname,
        opname_src=opname_src,
        opcode=instruction["opcode"],
        category_args=category_args,
        inst_category=inst_category,
        vendor_name=vendor_name,
        traits=existing_info.get("traits", ""),
        summary=summary,
        description=description,
        availability=avail,
        args=arguments,
        results=results,
        extras=existing_info.get("extras", ""),
    )


def get_string_between(base, start, end):
    """Extracts a substring with a specified start and end from a string.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string
        itself if the substring wasn't found.

    Raises:
      - AssertionError if `start` is present but no `end` follows it.
    """
    split = base.split(start, 1)
    if len(split) == 2:
        rest = split[1].split(end, 1)
        assert len(rest) == 2, (
            'cannot find end "{end}" while extracting substring '
            "starting with {start}".format(start=start, end=end)
        )
        # rest[0] is returned verbatim: it cannot contain `end`, and trimming
        # with str.rstrip(end) would strip a *set of characters*, eating
        # legitimate trailing "}", "]", ";", ")" or newlines of the content.
        return rest[0], rest[1]
    return "", split[0]


def get_string_between_nested(base, start, end):
    """Extracts a substring with a nested start and end from a string.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string
        itself if the substring wasn't found.

    Raises:
      - AssertionError if the delimiters opened by `start` are never closed.
    """
    split = base.split(start, 1)
    if len(split) == 2:
        # Handle nesting delimiters
        rest = split[1]
        unmatched_start = 1
        index = 0
        while unmatched_start > 0 and index < len(rest):
            # startswith(..., index) tests in place instead of copying the
            # tail of the string on every step.
            if rest.startswith(end, index):
                unmatched_start -= 1
                if unmatched_start == 0:
                    break
                index += len(end)
            elif rest.startswith(start, index):
                unmatched_start += 1
                index += len(start)
            else:
                index += 1

        assert index < len(rest), (
            'cannot find end "{end}" while extracting substring '
            'starting with "{start}"'.format(start=start, end=end)
        )
        return rest[:index], rest[index + len(end) :]
    return "", split[0]


def extract_td_op_info(op_def):
    """Extracts potentially manually specified sections in op's definition.

    Arguments:
      - A string containing the op's TableGen definition

    Returns:
      - A dict containing potential manually specified sections
    """
    # Get opname
    prefix = "def SPIRV_"
    suffix = "Op"
    opname = [
        o[len(prefix) : -len(suffix)]
        for o in re.findall(prefix + r"\w+" + suffix, op_def)
    ]
    assert len(opname) == 1, "more than one ops in the same section!"
    opname = opname[0]

    # Get instruction category
    prefix = "SPIRV_"
    inst_category = [
        o[len(prefix) :]
        for o in re.findall(prefix + r"\w+Op\b", op_def.split(":", 1)[1])
    ]
    assert len(inst_category) <= 1, "more than one ops in the same section!"
    inst_category = inst_category[0] if len(inst_category) == 1 else "Op"

    # Get category_args
    op_tmpl_params, _ = get_string_between_nested(op_def, "<", ">")
    opstringname, rest = get_string_between(op_tmpl_params, '"', '"')
    category_args = rest.split("[", 1)[0]
    category_args = category_args.rsplit(",", 1)[0]

    # Get traits
    traits, _ = get_string_between_nested(rest, "[", "]")

    # Get description
    description, rest = get_string_between(op_def, "let description = [{\n", "}];\n")

    # Get arguments
    args, rest = get_string_between(rest, " let arguments = (ins", ");\n")

    # Get results
    results, rest = get_string_between(rest, " let results = (outs", ");\n")

    # Whatever remains after the known sections is kept as manual extras.
    extras = rest.strip(" }\n")
    if extras:
        extras = "\n {}\n".format(extras)

    return {
        # Prefix with 'Op' to make it consistent with SPIR-V spec
        "opname": "Op{}".format(opname),
        "inst_category": inst_category,
        "category_args": category_args,
        "traits": traits,
        "description": description,
        "arguments": args,
        "results": results,
        "extras": extras,
    }


def update_td_op_definitions(
    path, instructions, docs, filter_list, inst_category, settings
):
    """Updates SPIRVOps.td with newly generated op definition.

    The file at `path` is rewritten in place; nothing is returned and
    `filter_list` is not modified.

    Arguments:
      - path: path to SPIRVOps.td
      - instructions: SPIR-V JSON grammar for all instructions
      - docs: SPIR-V HTML doc for all instructions
      - filter_list: a list containing new opnames to include
      - inst_category: the instruction category to use for ops that do not
        exist in the file yet
      - settings: an object whose `gen_cl_ops` attribute selects OpenCL
        extended instruction set handling
    """
    with open(path, "r") as f:
        content = f.read()

    # Split the file into chunks, each containing one op.
    ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
    header = ops[0]
    footer = ops[-1]
    ops = ops[1:-1]

    # For each existing op, extract the manually-written sections out to
    # retain them when re-generating the ops. Also add the existing ops to
    # the set of ops to emit.
    name_op_map = {}  # Map from opname to its existing ODS definition
    op_info_dict = {}
    for op in ops:
        info_dict = extract_td_op_info(op)
        opname = info_dict["opname"]
        name_op_map[opname] = op
        op_info_dict[opname] = info_dict
    all_opnames = sorted(set(filter_list) | set(name_op_map))

    # Index the grammar by opname, keeping the first entry for each name.
    inst_by_opname = {}
    for inst in instructions:
        inst_by_opname.setdefault(inst["opname"], inst)

    def fix_opname(src):
        # OpenCL ops are named e.g. "OpCLExp" in the .td file but "exp" in
        # the extended instruction set grammar.
        if settings.gen_cl_ops:
            return src.replace("CL", "").lower()
        return src

    op_defs = []
    for opname in all_opnames:
        # Find the grammar spec for this op
        fixed_opname = fix_opname(opname)
        instruction = inst_by_opname.get(fixed_opname)
        if instruction is None:
            # This is an op added by us; use the existing ODS definition.
            op_defs.append(name_op_map[opname])
            continue

        op_defs.append(
            get_op_definition(
                instruction,
                opname,
                docs[fixed_opname],
                op_info_dict.get(opname, {"inst_category": inst_category}),
                settings,
            )
        )

    # Substitute the old op definitions
    op_defs = [header] + op_defs + [footer]
    content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)

    with open(path, "w") as f:
        f.write(content)


def main():
    """Parses the command line and runs the requested update steps."""
    import argparse

    cli_parser = argparse.ArgumentParser(
        description="Update SPIR-V dialect definitions using SPIR-V spec"
    )

    cli_parser.add_argument(
        "--base-td-path",
        dest="base_td_path",
        type=str,
        default=None,
        help="Path to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--op-td-path",
        dest="op_td_path",
        type=str,
        default=None,
        help="Path to SPIRVOps.td",
    )

    cli_parser.add_argument(
        "--new-enum",
        dest="new_enum",
        type=str,
        default=None,
        help="SPIR-V enum to be added to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-opcodes",
        dest="new_opcodes",
        type=str,
        default=None,
        nargs="*",
        help="update SPIR-V opcodes in SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-inst",
        dest="new_inst",
        type=str,
        default=None,
        nargs="*",
        help="SPIR-V instruction to be added to ops file",
    )
    cli_parser.add_argument(
        "--inst-category",
        dest="inst_category",
        type=str,
        default="Op",
        help="SPIR-V instruction category used for choosing "
        "the TableGen base class to define this op",
    )
    cli_parser.add_argument(
        "--gen-cl-ops",
        dest="gen_cl_ops",
        help="Generate OpenCL Extended Instruction Set op",
        action="store_true",
    )
    cli_parser.set_defaults(gen_cl_ops=False)
    cli_parser.add_argument(
        "--gen-inst-coverage", dest="gen_inst_coverage", action="store_true"
    )
    cli_parser.set_defaults(gen_inst_coverage=False)

    args = cli_parser.parse_args()

    if args.gen_cl_ops:
        ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
        ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
    else:
        ext_html_url = None
        ext_json_url = None

    operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)

    # Define new enum attr
    if args.new_enum is not None:
        assert args.base_td_path is not None
        filter_list = [args.new_enum] if args.new_enum else []
        update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

    # Define new opcode
    if args.new_opcodes is not None:
        assert args.base_td_path is not None
        update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

    # Define new op
    if args.new_inst is not None:
        assert args.op_td_path is not None
        docs = get_spirv_doc_from_html_spec(ext_html_url, args)
        update_td_op_definitions(
            args.op_td_path,
            instructions,
            docs,
            args.new_inst,
            args.inst_category,
            args,
        )
        print("Done. Note that this script just generates a template; ", end="")
        print("please read the spec and update traits, arguments, and ", end="")
        print("results accordingly.")

    if args.gen_inst_coverage:
        # The report reads the opcode section of SPIRVBase.td.
        assert args.base_td_path is not None
        gen_instr_coverage_report(args.base_td_path, instructions)


if __name__ == "__main__":
    main()