xref: /llvm-project/mlir/utils/spirv/gen_spirv_dialect.py (revision 6e75eec866133620dcba956bc7d6dbc554642249)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5# See https://llvm.org/LICENSE.txt for license information.
6# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7
8# Script for updating SPIR-V dialect by scraping information from SPIR-V
9# HTML and JSON specs from the Internet.
10#
11# For example, to define the enum attribute for SPIR-V memory model:
12#
13# ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \
14#                        --new-enum MemoryModel
15#
16# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
17# SPIR-V enum classes.
18
19import itertools
20import math
21import re
22import requests
23import textwrap
24import yaml
25
# Core SPIR-V specification sources. The HTML spec is the default page scraped
# for instruction documentation; the JSON grammar is always fetched for the
# operand kinds (and for the instructions unless another grammar URL is given).
SPIRV_HTML_SPEC_URL = (
    "https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html"
)
SPIRV_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json"

# HTML spec and JSON grammar of the OpenCL extended instruction set.
# NOTE(review): presumably passed as the `url` arguments of the spec-fetching
# functions when generating OpenCL ops -- confirm against the callers.
SPIRV_CL_EXT_HTML_SPEC_URL = "https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html"
SPIRV_CL_EXT_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json"

# NOTE(review): presumably the separator placed between autogenerated op
# definitions -- confirm against its users.
AUTOGEN_OP_DEF_SEPARATOR = "\n// -----\n\n"
# Marker texts that delimit the autogenerated enum and opcode sections of the
# .td file. The update_td_* functions split the file content on them and expect
# exactly two occurrences (begin and end) of each.
AUTOGEN_ENUM_SECTION_MARKER = "enum section. Generated from SPIR-V spec; DO NOT MODIFY!"
AUTOGEN_OPCODE_SECTION_MARKER = (
    "opcode section. Generated from SPIR-V spec; DO NOT MODIFY!"
)
39
40
def get_spirv_doc_from_html_spec(url, settings):
    """Extracts instruction documentation from SPIR-V HTML spec.

    Arguments:
      - url: URL of the HTML spec to download; SPIRV_HTML_SPEC_URL is used when
        this is None.
      - settings: an object whose `gen_cl_ops` attribute selects which page
        layout is parsed (presumably the OpenCL extended instruction set page
        when truthy, the core spec page otherwise).

    Returns:
      - A dict mapping from instruction name (the id of the anchor in the
        instruction's table) to its documentation text.

    Raises:
      - requests.HTTPError if the server answers with an error status.
    """
    if url is None:
        url = SPIRV_HTML_SPEC_URL

    response = requests.get(url)
    # Fail loudly on an HTTP error instead of parsing an error page, which
    # would otherwise surface later as an obscure AttributeError on None.
    response.raise_for_status()
    spec = response.content

    from bs4 import BeautifulSoup

    spirv = BeautifulSoup(spec, "html.parser")

    doc = {}

    if settings.gen_cl_ops:
        section_anchor = spirv.find("h2", {"id": "_binary_form"})
        for section in section_anchor.parent.find_all("div", {"class": "sect2"}):
            for table in section.find_all("table"):
                inst_html = table.tbody.tr.td
                opname = inst_html.a["id"]
                # Ignore the first line, which is just the opname.
                doc[opname] = inst_html.text.split("\n", 1)[1].strip()
    else:
        section_anchor = spirv.find("h3", {"id": "_instructions_3"})
        for section in section_anchor.parent.find_all("div", {"class": "sect3"}):
            for table in section.find_all("table"):
                # In this layout the documentation lives in the first
                # paragraph of the first table cell.
                inst_html = table.tbody.tr.td.p
                opname = inst_html.a["id"]
                # Ignore the first line, which is just the opname.
                doc[opname] = inst_html.text.split("\n", 1)[1].strip()

    return doc
77
78
def get_spirv_grammar_from_json_spec(url):
    """Extracts operand kind and instruction grammar from SPIR-V JSON spec.

    The operand kinds are always taken from the core grammar at
    SPIRV_JSON_SPEC_URL. The instructions come from the core grammar as well
    unless `url` is given, in which case they are taken from the grammar
    downloaded from `url` instead.

    Arguments:
      - url: URL of an additional JSON grammar to take the instructions from,
        or None to use the core grammar's instructions.

    Returns:
      - A list containing all operand kinds' grammar
      - A list containing all instructions' grammar

    Raises:
      - requests.HTTPError if a server answers with an error status.
    """
    import json

    response = requests.get(SPIRV_JSON_SPEC_URL)
    # Surface HTTP failures directly rather than as a JSON decoding error.
    response.raise_for_status()
    spirv = json.loads(response.content)

    if url is None:
        return spirv["operand_kinds"], spirv["instructions"]

    response_ext = requests.get(url)
    response_ext.raise_for_status()
    spirv_ext = json.loads(response_ext.content)

    return spirv["operand_kinds"], spirv_ext["instructions"]
101
102
def split_list_into_sublists(items):
    """Split the list of items into multiple sublists.

    This is to make sure the string composed from each sublist won't exceed
    80 characters. An item that is by itself longer than the budget is placed
    alone in its own sublist; no empty sublist is ever produced.

    Arguments:
      - items: a list of strings

    Returns:
      - A list of non-empty sublists that, concatenated, give back `items` in
        the original order. Empty when `items` is empty.
    """
    chunks = []
    chunk = []
    chunk_len = 0

    for item in items:
        # Each item is charged two extra characters for the ", " separator.
        chunk_len += len(item) + 2
        # Only flush a chunk that actually holds something: an oversized first
        # item must not emit a leading empty sublist.
        if chunk and chunk_len > 80:
            chunks.append(chunk)
            chunk = []
            chunk_len = len(item) + 2
        chunk.append(item)

    if chunk:
        chunks.append(chunk)

    return chunks
128
129
def toposort(dag, sort_fn):
    """Returns the nodes of the given dag in topological order.

    Nodes are emitted batch by batch: each batch holds every node that has no
    remaining incoming node, ordered by `sort_fn`. The given dict is not
    modified.

    Arguments:
      - dag: a dict mapping from a node to the set of its incoming nodes.
      - sort_fn: key function used to order the nodes within one batch.

    Returns:
      A list containing topologically sorted nodes.
    """
    ordered = []
    pending = dag

    while True:
        # Nodes whose incoming nodes have all been emitted already.
        ready = {node for node, incoming in pending.items() if not incoming}
        if not ready:
            break
        ordered.extend(sorted(ready, key=sort_fn))
        # Drop the emitted nodes, both as nodes and as incoming nodes.
        pending = {
            node: (incoming - ready)
            for node, incoming in pending.items()
            if node not in ready
        }

    # Anything left over can never become ready.
    assert not pending, "found cyclic dependency"
    return ordered
160
161
def toposort_capabilities(all_cases):
    """Topologically sorts capabilities by their implied capabilities.

    Arguments:
      - all_cases: all capability cases; each is a dict with an 'enumerant'
        (symbol), a 'value', and optionally a 'capabilities' list naming the
        capabilities it refers to.

    Returns:
      A list containing topologically sorted capability (symbol, value) pairs.
      Capabilities in the same batch are ordered by their value.
    """
    # Symbol -> value, used both as the in-batch sort key and for the result.
    name_to_value = {case["enumerant"]: case["value"] for case in all_cases}
    # Symbol -> uniqued set of capabilities that must come before it.
    dag = {
        case["enumerant"]: set(case.get("capabilities", [])) for case in all_cases
    }

    sorted_caps = toposort(dag, name_to_value.__getitem__)
    return [(symbol, name_to_value[symbol]) for symbol in sorted_caps]
187
188
def get_availability_spec(enum_case, for_op, for_cap):
    """Builds the availability specification string for the given enum case.

    Arguments:
      - enum_case: the grammar dict to generate the availability spec for. Its
        optional 'version', 'lastVersion', 'extensions' and 'capabilities'
        entries are read.
      - for_op: bool value indicating whether this is the availability spec for an
        op itself. Ops get explicit default requirements filled in and use
        `let availability`; the spec is omitted when it equals the defaults.
      - for_cap: bool value indicating whether this is the availability spec for
        capabilities themselves. The 'capabilities' entry is then emitted as an
        `implies` list instead of a Capability requirement.

    Returns:
      - The optional `implies` list followed by the optional
        `availability = [...];` list, or an empty string if neither applies.
    """
    assert not (for_op and for_cap), "cannot set both for_op and for_cap"

    default_min_version = "MinVersion<SPIRV_V_1_0>"
    default_max_version = "MaxVersion<SPIRV_V_1_6>"
    default_cap = "Capability<[]>"
    default_ext = "Extension<[]>"

    # Minimal version; the grammar spells "no version" as the string "None".
    min_version = enum_case.get("version", "")
    if min_version == "None":
        min_version = ""
    elif min_version:
        min_version = "MinVersion<SPIRV_V_{}>".format(min_version.replace(".", "_"))

    # Maximal version.
    max_version = enum_case.get("lastVersion", "")
    if max_version:
        max_version = "MaxVersion<SPIRV_V_{}>".format(max_version.replace(".", "_"))

    # Extensions.
    exts = enum_case.get("extensions", [])
    if exts:
        exts = "Extension<[{}]>".format(", ".join(sorted(set(exts))))
        # A symbol available via an extension can be supported by *any* SPIR-V
        # version as long as the extension is provided. The grammar's 'version'
        # field then only tells since when the symbol is a core symbol, so it
        # must not be turned into a minimal version requirement.
        min_version = ""

    # Capabilities.
    caps = enum_case.get("capabilities", [])
    implies = ""
    if caps:
        prefixed_caps = ", ".join(
            "SPIRV_C_{}".format(cap) for cap in sorted(set(caps))
        )
        if for_cap:
            # For capabilities themselves the grammar's "capabilities" field
            # lists what the capability implies, not what it requires.
            caps = ""
            implies = "list<I32EnumAttrCase> implies = [{}];".format(prefixed_caps)
        else:
            caps = "Capability<[{}]>".format(prefixed_caps)

    if for_op:
        # TODO: delete this once ODS can support dialect-specific content
        # and we can use omission to mean no requirements.
        min_version = min_version or default_min_version
        max_version = max_version or default_max_version
        exts = exts or default_ext
        caps = caps or default_cap

    requirements = [
        spec for spec in (min_version, max_version, exts, caps) if spec
    ]
    # For ops the defaults already live in the SPIRV_Op class, so a spec that
    # merely repeats them is omitted.
    is_default_op_spec = (
        for_op
        and min_version == default_min_version
        and max_version == default_max_version
        and caps == default_cap
        and exts == default_ext
    )

    avail = ""
    if requirements and not is_default_op_spec:
        avail = "{} availability = [\n    {}\n  ];".format(
            "let" if for_op else "list<Availability>",
            ",\n    ".join(requirements),
        )

    return "\n  ".join(part for part in (implies, avail) if part)
288
289
def gen_operand_kind_enum_attr(operand_kind):
    """Builds the TableGen EnumAttr definition for the given operand kind.

    Arguments:
      - operand_kind: the operand kind's grammar dict; its 'kind', 'category'
        and 'enumerants' entries are read.

    Returns:
      - The operand kind's name
      - A string containing the enum case definitions followed by the TableGen
        EnumAttr definition
      Both are empty strings when the operand kind has no 'enumerants'.
    """
    if "enumerants" not in operand_kind:
        return "", ""

    kind_name = operand_kind["kind"]
    is_bit_enum = operand_kind["category"] == "BitEnum"
    enumerants = operand_kind["enumerants"]
    is_capability = kind_name == "Capability"
    # The acronym is made of the upper-case letters of the kind's name.
    kind_acronym = "".join(ch for ch in kind_name if "A" <= ch <= "Z")

    def symbol_for(case_name):
        # Dim cases are handled specially to avoid having numbers as the start
        # of symbols, which does not play well with C++ and the MLIR parser.
        if kind_name == "Dim" and case_name in ("1D", "2D", "3D"):
            return "Dim{}".format(case_name)
        return case_name

    grammar_by_name = {case["enumerant"]: case for case in enumerants}

    if is_capability:
        # Capability cases must be sorted topologically because a capability
        # can refer to another via the 'implies' field.
        kind_cases = toposort_capabilities(enumerants)
    else:
        kind_cases = [(case["enumerant"], case["value"]) for case in enumerants]
    max_len = max([len(symbol) for symbol, _ in kind_cases])

    # Generate the definition for each enum case
    case_category = "I32Bit" if is_bit_enum else "I32"
    case_fmt = (
        "def SPIRV_{acronym}_{case_name} {colon:>{offset}} "
        '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}'
    )

    def render_case(name, raw_value):
        # Bit enum values are hexadecimal strings in the grammar.
        number = int(raw_value, base=16) if is_bit_enum else int(raw_value)
        avail = get_availability_spec(grammar_by_name[name], False, is_capability)
        if not is_bit_enum:
            suffix, value_part = "", ", {}".format(number)
        elif number == 0:
            suffix, value_part = "None", ""
        else:
            # Bit cases are defined by bit position rather than by mask.
            suffix, value_part = "Bit", ", {}".format(int(math.log2(number)))
        return case_fmt.format(
            category=case_category,
            suffix=suffix,
            acronym=kind_acronym,
            case_name=name,
            symbol=symbol_for(name),
            case_value_part=value_part,
            avail=" {{\n  {}\n}}".format(avail) if avail else ";",
            colon=":",
            offset=(max_len + 1 - len(name)),
        )

    case_defs = "\n".join(render_case(name, value) for name, value in kind_cases)

    # Generate the list of enum case names, split over multiple indented lines
    case_symbols = [
        "SPIRV_{acronym}_{symbol}".format(acronym=kind_acronym, symbol=name)
        for name, _ in kind_cases
    ]
    indent = "{:6}".format("")
    case_names = ",\n".join(
        indent + ", ".join(sublist)
        for sublist in split_list_into_sublists(case_symbols)
    )

    # Generate the enum attribute definition
    enum_attr = """def SPIRV_{name}Attr :
    SPIRV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [
{cases}
    ]>;""".format(
        name=kind_name,
        snake_name=snake_casify(kind_name),
        category="Bit" if is_bit_enum else "I32",
        cases=case_names,
    )
    return kind_name, case_defs + "\n\n" + enum_attr
395
396
def gen_opcode(instructions):
    """Builds the TableGen definition mapping opname to opcode.

    Arguments:
      - instructions: a non-empty list of instructions' grammar; the 'opname'
        and 'opcode' of each are read.

    Returns:
      - A string containing one SPIRV_OC_* case per instruction followed by
        the TableGen SPIRV_OpcodeAttr definition
    """
    opnames = [inst["opname"] for inst in instructions]
    # Used to align the colons of all case definitions.
    max_len = max([len(opname) for opname in opnames])

    case_fmt = (
        "def SPIRV_OC_{name} {colon:>{offset}} " 'I32EnumAttrCase<"{name}", {value}>;'
    )
    case_lines = []
    for inst in instructions:
        opname = inst["opname"]
        case_lines.append(
            case_fmt.format(
                name=opname,
                value=inst["opcode"],
                colon=":",
                offset=(max_len + 1 - len(opname)),
            )
        )

    # List of case symbols, wrapped into indented lines.
    symbols = ["SPIRV_OC_{name}".format(name=opname) for opname in opnames]
    indent = "{:6}".format("")
    rows = [indent + ", ".join(group) for group in split_list_into_sublists(symbols)]

    enum_attr = (
        "def SPIRV_OpcodeAttr :\n"
        '    SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '
        '"opcode", [\n'
        "{lst}\n"
        "    ]>;"
    ).format(name="Opcode", lst=",\n".join(rows))

    return "\n".join(case_lines) + "\n\n" + enum_attr
432
433
def map_cap_to_opnames(instructions):
    """Groups instruction names by the capabilities that enable them.

    Instructions without a 'capabilities' entry are grouped under the
    placeholder key "0_core_0".

    Arguments:
      - instructions: a list containing a subset of SPIR-V instructions' grammar
    Returns:
      - A map with keys representing capabilities and values of lists of
      instructions enabled by the corresponding key
    """
    opnames_by_cap = {}

    for inst in instructions:
        for cap in inst.get("capabilities", ["0_core_0"]):
            opnames_by_cap.setdefault(cap, []).append(inst["opname"])

    return opnames_by_cap
453
454
def gen_instr_coverage_report(path, instructions):
    """Prints to standard output a YAML report of current instruction coverage.

    An instruction counts as supported when an opcode case for it exists in
    the autogenerated opcode section of the .td file. The report has one entry
    per capability that still has unsupported instructions.

    Arguments:
      - path: the path to SPIRVBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
    """
    with open(path, "r") as td_file:
        sections = td_file.read().split(AUTOGEN_OPCODE_SECTION_MARKER)

    # Opnames already defined in the autogenerated opcode section.
    prefix = "def SPIRV_OC_"
    existing_opcodes = {
        found[len(prefix) :] for found in re.findall(prefix + r"\w+", sections[1])
    }

    supported = [inst for inst in instructions if inst["opname"] in existing_opcodes]
    unsupported = [
        inst for inst in instructions if inst["opname"] not in existing_opcodes
    ]

    rem_cap_to_instr = map_cap_to_opnames(unsupported)
    ex_cap_to_instr = map_cap_to_opnames(supported)

    report = {}
    for cap, missing in rem_cap_to_instr.items():
        if cap in ex_cap_to_instr:
            covered = ex_cap_to_instr[cap]
            coverage = len(covered) / (len(covered) + len(missing))
        else:
            covered = []
            coverage = 0.0
        report[cap] = {
            "Supported Instructions": covered,
            "Unsupported Instructions": missing,
            # Percentage truncated to an integer.
            "Coverage": "{}%".format(int(coverage * 100)),
        }

    print(yaml.dump(report))
508
509
def update_td_opcodes(path, instructions, filter_list):
    """Rewrites the opcode section of SPIRVBase.td with regenerated cases.

    The regenerated section covers both the opnames already present in the
    section and the ones in `filter_list`, sorted by opcode.

    Arguments:
      - path: the path to SPIRVBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
      - filter_list: a list containing new opnames to add; it is extended in
        place with the opnames already present in the file
    """
    with open(path, "r") as td_file:
        sections = td_file.read().split(AUTOGEN_OPCODE_SECTION_MARKER)
    # The marker must appear exactly twice: at the begin and end of the section.
    assert len(sections) == 3
    before, generated, after = sections

    # Extend opcode list with existing list
    prefix = "def SPIRV_OC_"
    filter_list.extend(
        found[len(prefix) :] for found in re.findall(prefix + r"\w+", generated)
    )
    wanted = set(filter_list)

    # Generate the opcode for all requested instructions, sorted by opcode
    selected = sorted(
        (inst for inst in instructions if inst["opname"] in wanted),
        key=lambda inst: inst["opcode"],
    )
    opcode = gen_opcode(selected)

    # Substitute the opcode
    updated = "".join(
        [
            before,
            AUTOGEN_OPCODE_SECTION_MARKER,
            "\n\n",
            opcode,
            "\n\n// End ",
            AUTOGEN_OPCODE_SECTION_MARKER,
            after,
        ]
    )

    with open(path, "w") as td_file:
        td_file.write(updated)
554
555
def update_td_enum_attrs(path, operand_kinds, filter_list):
    """Rewrites the enum section of SPIRVBase.td with regenerated definitions.

    The regenerated section covers both the enums already defined in the
    section and the ones in `filter_list`.

    Arguments:
      - path: the path to SPIRVBase.td
      - operand_kinds: a list containing all operand kinds' grammar
      - filter_list: a list containing new enums to add; it is extended in
        place with the enums already defined in the file
    """
    with open(path, "r") as td_file:
        sections = td_file.read().split(AUTOGEN_ENUM_SECTION_MARKER)
    # The marker must appear exactly twice: at the begin and end of the section.
    assert len(sections) == 3
    before, generated, after = sections

    # Extend filter list with existing enum definitions
    prefix = "def SPIRV_"
    suffix = "Attr"
    filter_list.extend(
        found[len(prefix) : -len(suffix)]
        for found in re.findall(prefix + r"\w+" + suffix, generated)
    )

    # Generate definitions for all enums in filter list, sorted alphabetically
    # according to enum name
    named_defs = sorted(
        (
            gen_operand_kind_enum_attr(kind)
            for kind in operand_kinds
            if kind["kind"] in filter_list
        ),
        key=lambda enum: enum[0],
    )
    # Capability's definition goes first because capability cases are
    # referenced by the definitions that follow.
    capability_defs = [text for name, text in named_defs if name == "Capability"]
    other_defs = [text for name, text in named_defs if name != "Capability"]

    # Substitute the old section
    updated = "".join(
        [
            before,
            AUTOGEN_ENUM_SECTION_MARKER,
            "\n\n",
            "\n\n".join(capability_defs + other_defs),
            "\n\n// End ",
            AUTOGEN_ENUM_SECTION_MARKER,
            after,
        ]
    )

    with open(path, "w") as td_file:
        td_file.write(updated)
607
608
def snake_casify(name):
    """Converts the given name to snake_case.

    An underscore is inserted before every ASCII upper-case letter except one
    at the very start, then the whole name is lower-cased.
    """
    # Zero-width match in front of each upper-case letter not at the start.
    word_boundary = re.compile(r"(?<!^)(?=[A-Z])")
    with_underscores = word_boundary.sub("_", name)
    return with_underscores.lower()
612
613
def map_spec_operand_to_ods_argument(operand):
    """Translates an operand in SPIR-V JSON spec into an op argument in ODS.

    Arguments:
      - operand: a dict containing the operand's 'kind' and optionally its
        'quantifier' ("?" or "*") and 'name'

    Returns:
      - A string of the form "<type>:$<name>" for the argument. The name is
        the snake_case form of the operand's name, or the lower-cased kind
        when the operand has no name.
    """
    kind = operand["kind"]
    quantifier = operand.get("quantifier", "")

    # These instruction "operands" are for encoding the results; they should
    # not be handled here.
    assert kind != "IdResultType", 'unexpected to handle "IdResultType" kind'
    assert kind != "IdResult", 'unexpected to handle "IdResult" kind'

    unimplemented_kinds = (
        "LiteralString",
        "LiteralContextDependentNumber",
        "LiteralExtInstInteger",
        "LiteralSpecConstantOpInteger",
        "PairLiteralIntegerIdRef",
        "PairIdRefLiteralInteger",
        "PairIdRefIdRef",
    )

    if kind == "IdRef":
        # Any quantifier other than none/optional is treated as variadic.
        arg_type = {"": "SPIRV_Type", "?": "Optional<SPIRV_Type>"}.get(
            quantifier, "Variadic<SPIRV_Type>"
        )
    elif kind in ("IdMemorySemantics", "IdScope"):
        # TODO: Need to further constrain 'IdMemorySemantics'
        # and 'IdScope' given that they should be generated from OpConstant.
        assert quantifier == "", (
            "unexpected to have optional/variadic memory " "semantics or scope <id>"
        )
        # Drop the leading "Id" to get the enum attribute's name.
        arg_type = "SPIRV_" + kind[2:] + "Attr"
    elif kind == "LiteralInteger":
        arg_type = {"": "I32Attr", "?": "OptionalAttr<I32Attr>"}.get(
            quantifier, "OptionalAttr<I32ArrayAttr>"
        )
    elif kind in unimplemented_kinds:
        assert False, '"{}" kind unimplemented'.format(kind)
    else:
        # The rest are all enum operands that we represent with op attributes.
        assert quantifier != "*", "unexpected to have variadic enum attribute"
        arg_type = "SPIRV_{}Attr".format(kind)
        if quantifier == "?":
            arg_type = "OptionalAttr<{}>".format(arg_type)

    spec_name = operand.get("name", "")
    arg_name = snake_casify(spec_name) if spec_name else kind.lower()

    return "{}:${}".format(arg_type, arg_name)
673
674
def get_description(text, appendix):
    """Composes the description for the given SPIR-V instruction.

    Arguments:
      - text: Textual description of the operation as string.
      - appendix: Additional contents to attach in description as string,
                  including IR examples, and others.

    Returns:
      - A string that corresponds to the description of the Tablegen op: the
        text, a marker ending the autogenerated part, then the appendix.
    """
    template = "{text}\n\n    <!-- End of AutoGen section -->\n{appendix}\n  "
    description = template.format(text=text, appendix=appendix)
    return description
688
689
def get_op_definition(
    instruction, opname, doc, existing_info, settings
):
    """Generates the TableGen op definition for the given SPIR-V instruction.

    Arguments:
      - instruction: the instruction's SPIR-V JSON grammar
      - opname: the op's name; used in the `def` name when generating OpenCL
        ops, and checked for a leading "Op" that is then dropped from the
        instruction's own opname
      - doc: the instruction's SPIR-V HTML doc; its first line becomes the
        summary and the remaining lines the description text
      - existing_info: a dict containing potential manually specified sections for
        this instruction. The keys read are 'inst_category', 'category_args',
        'traits', 'results', 'arguments', 'description' and 'extras'.
      - settings: an object whose `gen_cl_ops` attribute selects the header
        format of the definition

    Returns:
      - A string containing the TableGen op definition
    """
    # The definition is assembled as one format string that is extended below
    # depending on the instruction category and filled in at the very end.
    if settings.gen_cl_ops:
        fmt_str = (
            "def SPIRV_{opname}Op : "
            'SPIRV_{inst_category}<"{opname_src}", {opcode}, <<Insert result type>> > '
            "{{\n  let summary = {summary};\n\n  let description = "
            "[{{\n{description}}}];{availability}\n"
        )
    else:
        fmt_str = (
            "def SPIRV_{vendor_name}{opname_src}Op : "
            'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> '
            "{{\n  let summary = {summary};\n\n  let description = "
            "[{{\n{description}}}];{availability}\n"
        )

    vendor_name = ""
    inst_category = existing_info.get("inst_category", "Op")
    if inst_category == "Op":
        # Only plain ops get explicit arguments and results sections.
        fmt_str += (
            "\n  let arguments = (ins{args});\n\n" "  let results = (outs{results});\n"
        )
    elif inst_category.endswith("VendorOp"):
        # The part of the category before "VendorOp" is the vendor name.
        vendor_name = inst_category.split("VendorOp")[0].upper()
        assert len(vendor_name) != 0, "Invalid instruction category"

    fmt_str += "{extras}" "}}\n"

    opname_src = instruction["opname"]
    if opname.startswith("Op"):
        # Drop the leading "Op".
        opname_src = opname_src[2:]
    if len(vendor_name) > 0:
        # The vendor name is emitted as a prefix of the def name instead, so
        # strip it from the end of the op name.
        assert opname_src.endswith(
            vendor_name
        ), "op name does not match the instruction category"
        opname_src = opname_src[: -len(vendor_name)]

    category_args = existing_info.get("category_args", "")

    # The first line of the doc is the summary; the rest is the description.
    if "\n" in doc:
        summary, text = doc.split("\n", 1)
    else:
        summary = doc
        text = ""
    wrapper = textwrap.TextWrapper(
        width=76, initial_indent="    ", subsequent_indent="    "
    )

    # Format summary. If the summary can fit in the same line, we print it out
    # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
    summary = summary.strip()
    if len(summary) + len('  let summary = "";') <= 80:
        summary = '"{}"'.format(summary)
    else:
        summary = "[{{\n{}\n  }}]".format(wrapper.fill(summary))

    # Wrap text: every non-empty line becomes its own wrapped paragraph.
    text = text.split("\n")
    text = [wrapper.fill(line) for line in text if line]
    text = "\n\n".join(text)

    operands = instruction.get("operands", [])

    # Op availability
    avail = get_availability_spec(instruction, True, False)
    if avail:
        avail = "\n\n  {0}".format(avail)

    # Set op's result
    results = ""
    if len(operands) > 0 and operands[0]["kind"] == "IdResultType":
        results = "\n    SPIRV_Type:$result\n  "
        operands = operands[1:]
    # A manually specified results section takes precedence.
    if "results" in existing_info:
        results = existing_info["results"]

    # Ignore the operand standing for the result <id>
    if len(operands) > 0 and operands[0]["kind"] == "IdResult":
        operands = operands[1:]

    # Set op' argument; generated from the remaining operands only when no
    # manually specified arguments section exists.
    arguments = existing_info.get("arguments", None)
    if arguments is None:
        arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
        arguments = ",\n    ".join(arguments)
        if arguments:
            # Prepend and append whitespace for formatting
            arguments = "\n    {}\n  ".format(arguments)

    description = existing_info.get("description", None)
    if description is None:
        # No existing description: emit the spec text followed by placeholder
        # sections to be filled in by hand.
        assembly = (
            "\n    ```\n"
            "    [TODO]\n"
            "    ```\n\n"
            "    #### Example:\n\n"
            "    ```mlir\n"
            "    [TODO]\n"
            "    ```"
        )
        description = get_description(text, assembly)

    return fmt_str.format(
        opname=opname,
        opname_src=opname_src,
        opcode=instruction["opcode"],
        category_args=category_args,
        inst_category=inst_category,
        vendor_name=vendor_name,
        traits=existing_info.get("traits", ""),
        summary=summary,
        description=description,
        availability=avail,
        args=arguments,
        results=results,
        extras=existing_info.get("extras", ""),
    )
820
821
def get_string_between(base, start, end):
    """Extracts a substring with a specified start and end from a string.

    The substring runs from just after the first occurrence of `start` to just
    before the first occurrence of `end` that follows it. It is returned
    verbatim, including any trailing characters that also occur in `end`.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string itself
        if the substring wasn't found.

    Raises:
      - AssertionError if `start` is found but no `end` follows it.
    """
    split = base.split(start, 1)
    if len(split) == 2:
        rest = split[1].split(end, 1)
        assert len(rest) == 2, (
            'cannot find end "{end}" while extracting substring '
            "starting with {start}".format(start=start, end=end)
        )
        # The split already removed the end delimiter. Do not rstrip(end) here:
        # rstrip() strips a *set of characters* and would eat genuine trailing
        # content such as a closing bracket or a newline.
        return rest[0], rest[1]
    return "", split[0]
844
845
def get_string_between_nested(base, start, end):
    """Extracts a substring with a nested start and end from a string.

    The substring begins just after the first occurrence of `start` and ends
    just before the `end` that balances it, so inner start/end pairs are kept
    as part of the substring.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string itself
        if the substring wasn't found.

    Raises:
      - AssertionError if `start` is found but never balanced by an `end`.
    """
    split = base.split(start, 1)
    if len(split) == 2:
        # Handle nesting delimiters
        rest = split[1]
        unmatched_start = 1
        index = 0
        while unmatched_start > 0 and index < len(rest):
            # Match at an offset instead of slicing, which would copy the
            # remaining text on every step and make the scan quadratic.
            if rest.startswith(end, index):
                unmatched_start -= 1
                if unmatched_start == 0:
                    # `index` now points at the balancing end delimiter.
                    break
                index += len(end)
            elif rest.startswith(start, index):
                unmatched_start += 1
                index += len(start)
            else:
                index += 1

        assert index < len(rest), (
            'cannot find end "{end}" while extracting substring '
            'starting with "{start}"'.format(start=start, end=end)
        )
        return rest[:index], rest[index + len(end) :]
    return "", split[0]
883
884
def extract_td_op_info(op_def):
    """Extracts potentially manually specified sections in op's definition.

    Arguments:
      - op_def: A string containing the op's TableGen definition

    Returns:
      - A dict containing potential manually specified sections, with keys
        "opname", "inst_category", "category_args", "traits", "description",
        "arguments", "results" and "extras".

    Raises:
      - AssertionError if the section does not define exactly one op, names
        more than one base op class, or has an unterminated delimited part.
    """
    # Get opname
    def_prefix = "def SPIRV_"
    def_suffix = "Op"
    opname = [
        o[len(def_prefix) : -len(def_suffix)]
        for o in re.findall(def_prefix + r"\w+" + def_suffix, op_def)
    ]
    # Report the real count: zero matches is as much an error as several.
    assert len(opname) == 1, (
        "expected exactly one op in the same section, found {}!".format(
            len(opname)
        )
    )
    opname = opname[0]

    # Get instruction category: the base class named after the first ":" of
    # the definition, falling back to the plain "Op" category.
    category_prefix = "SPIRV_"
    inst_category = [
        o[len(category_prefix) :]
        for o in re.findall(
            category_prefix + r"\w+Op\b", op_def.split(":", 1)[1]
        )
    ]
    assert len(inst_category) <= 1, "more than one ops in the same section!"
    inst_category = inst_category[0] if len(inst_category) == 1 else "Op"

    # Get category_args: whatever sits between the quoted op name and the
    # trait list inside the template parameters, minus the trailing comma.
    op_tmpl_params, _ = get_string_between_nested(op_def, "<", ">")
    _, tmpl_rest = get_string_between(op_tmpl_params, '"', '"')
    category_args = tmpl_rest.split("[", 1)[0]
    category_args = category_args.rsplit(",", 1)[0]

    # Get traits
    traits, _ = get_string_between_nested(tmpl_rest, "[", "]")

    # Get description
    description, body_rest = get_string_between(
        op_def, "let description = [{\n", "}];\n"
    )

    # Get arguments
    args, body_rest = get_string_between(
        body_rest, "  let arguments = (ins", ");\n"
    )

    # Get results
    results, body_rest = get_string_between(
        body_rest, "  let results = (outs", ");\n"
    )

    # Whatever is left after the sections above is kept as-is, without the
    # surrounding spaces, closing braces and newlines.
    extras = body_rest.strip(" }\n")
    if extras:
        extras = "\n  {}\n".format(extras)

    return {
        # Prefix with 'Op' to make it consistent with SPIR-V spec
        "opname": "Op{}".format(opname),
        "inst_category": inst_category,
        "category_args": category_args,
        "traits": traits,
        "description": description,
        "arguments": args,
        "results": results,
        "extras": extras,
    }
945
946
def update_td_op_definitions(
    path, instructions, docs, filter_list, inst_category, settings
):
    """Updates SPIRVOps.td with newly generated op definition.

    The file at `path` is read, every op requested in `filter_list` or
    already present in the file is (re-)generated, and the result is written
    back to the same file. Nothing is returned.

    Arguments:
      - path: path to SPIRVOps.td
      - instructions: SPIR-V JSON grammar for all instructions
      - docs: SPIR-V HTML doc for all instructions
      - filter_list: a list containing new opnames to include; it is not
        modified
      - inst_category: instruction category passed on for ops that have no
        existing definition in the file
      - settings: parsed command-line options; `gen_cl_ops` is read here and
        the object is forwarded to get_op_definition
    """
    with open(path, "r") as f:
        content = f.read()

    # Split the file into chunks, each containing one op.
    chunks = content.split(AUTOGEN_OP_DEF_SEPARATOR)
    # Without a separator the header and the footer would be the very same
    # chunk and the file would be written back duplicated.
    assert len(chunks) >= 2, (
        "cannot find any op definition separator in {}".format(path)
    )
    header = chunks[0]
    footer = chunks[-1]
    existing_ops = chunks[1:-1]

    # For each existing op, extract the manually-written sections out to retain
    # them when re-generating the ops.
    name_op_map = {}  # Map from opname to its existing ODS definition
    op_info_dict = {}
    for op in existing_ops:
        info_dict = extract_td_op_info(op)
        opname = info_dict["opname"]
        name_op_map[opname] = op
        op_info_dict[opname] = info_dict

    # Regenerate the requested ops together with the existing ones, in sorted
    # order. A new collection is built so the caller's list stays untouched.
    all_opnames = sorted(set(filter_list).union(name_op_map))

    gen_cl_ops = settings.gen_cl_ops

    def fix_opname(src):
        # For OpenCL ops the grammar is searched using the opname with "CL"
        # removed and lowercased; otherwise the opname is used unchanged.
        if gen_cl_ops:
            return src.replace("CL", "").lower()
        return src

    op_defs = [header]
    for opname in all_opnames:
        # Find the grammar spec for this op
        fixed_opname = fix_opname(opname)
        instruction = next(
            (inst for inst in instructions if inst["opname"] == fixed_opname),
            None,
        )
        if instruction is None:
            # This is an op added by us; use the existing ODS definition.
            op_defs.append(name_op_map[opname])
            continue

        # Generation happens outside of any exception handler so that errors
        # raised while generating are not mistaken for "op not in grammar".
        op_defs.append(
            get_op_definition(
                instruction,
                opname,
                docs[fixed_opname],
                op_info_dict.get(opname, {"inst_category": inst_category}),
                settings,
            )
        )
    op_defs.append(footer)

    # Substitute the old op definitions
    content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)

    with open(path, "w") as f:
        f.write(content)
1017
1018
# Command-line entry point: parse the options, fetch the SPIR-V grammar, and
# run whichever update/report modes were requested.
if __name__ == "__main__":
    import argparse

    cli_parser = argparse.ArgumentParser(
        description="Update SPIR-V dialect definitions using SPIR-V spec"
    )

    cli_parser.add_argument(
        "--base-td-path",
        dest="base_td_path",
        type=str,
        default=None,
        help="Path to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--op-td-path",
        dest="op_td_path",
        type=str,
        default=None,
        help="Path to SPIRVOps.td",
    )

    cli_parser.add_argument(
        "--new-enum",
        dest="new_enum",
        type=str,
        default=None,
        help="SPIR-V enum to be added to SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-opcodes",
        dest="new_opcodes",
        type=str,
        default=None,
        nargs="*",
        help="update SPIR-V opcodes in SPIRVBase.td",
    )
    cli_parser.add_argument(
        "--new-inst",
        dest="new_inst",
        type=str,
        default=None,
        nargs="*",
        help="SPIR-V instruction to be added to ops file",
    )
    cli_parser.add_argument(
        "--inst-category",
        dest="inst_category",
        type=str,
        default="Op",
        help="SPIR-V instruction category used for choosing "
        "the TableGen base class to define this op",
    )
    cli_parser.add_argument(
        "--gen-cl-ops",
        dest="gen_cl_ops",
        help="Generate OpenCL Extended Instruction Set op",
        action="store_true",
    )
    cli_parser.set_defaults(gen_cl_ops=False)
    cli_parser.add_argument(
        "--gen-inst-coverage",
        dest="gen_inst_coverage",
        help="Generate an instruction coverage report; "
        "requires --base-td-path",
        action="store_true",
    )
    cli_parser.set_defaults(gen_inst_coverage=False)

    args = cli_parser.parse_args()

    # Validate the option combinations before any spec is downloaded, and
    # report problems as usage errors rather than via assert (which is
    # stripped under -O). --gen-inst-coverage hands base_td_path straight to
    # gen_instr_coverage_report, so it is checked like the other users of that
    # option (assumes the report reads that file; confirm against
    # gen_instr_coverage_report).
    needs_base_td = (
        args.new_enum is not None
        or args.new_opcodes is not None
        or args.gen_inst_coverage
    )
    if needs_base_td and args.base_td_path is None:
        cli_parser.error(
            "--base-td-path is required by --new-enum, --new-opcodes "
            "and --gen-inst-coverage"
        )
    if args.new_inst is not None and args.op_td_path is None:
        cli_parser.error("--op-td-path is required by --new-inst")

    # With no extension selected, None is passed for both URLs.
    if args.gen_cl_ops:
        ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL
        ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL
    else:
        ext_html_url = None
        ext_json_url = None

    operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url)

    # Define new enum attr
    if args.new_enum is not None:
        filter_list = [args.new_enum] if args.new_enum else []
        update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

    # Define new opcode
    if args.new_opcodes is not None:
        update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

    # Define new op
    if args.new_inst is not None:
        docs = get_spirv_doc_from_html_spec(ext_html_url, args)
        update_td_op_definitions(
            args.op_td_path,
            instructions,
            docs,
            args.new_inst,
            args.inst_category,
            args,
        )
        print(
            "Done. Note that this script just generates a template; "
            "please read the spec and update traits, arguments, and "
            "results accordingly."
        )

    if args.gen_inst_coverage:
        gen_instr_coverage_report(args.base_td_path, instructions)
1124