xref: /llvm-project/utils/bazel/llvm-project-overlay/mlir/tblgen.bzl (revision 1b232fa0e9864dde230db8da82a906c588baf792)
1# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
2# See https://llvm.org/LICENSE.txt for license information.
3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4"""BUILD extensions for MLIR table generation."""
5
6load("@bazel_skylib//lib:paths.bzl", "paths")
7
# Provider through which TableGen library targets hand their sources and
# include directories to the rules that consume them.
TdInfo = provider(
    doc = "Holds TableGen files and the dependencies and include paths necessary to" +
          " build them.",
    fields = {
        # Filled by td_library with the depset built from its own srcs plus
        # the sources of all of its deps.
        "transitive_sources": "td files transitively used by this rule.",
        # Filled by td_library with the depset of resolved include
        # directories of the target and all of its deps.
        "transitive_includes": (
            "include arguments to add to the final TableGen invocation. These" +
            " are the absolute directory paths that will be added with '-I'."
        ),
    },
)
19
20# For now we allow anything that provides DefaultInfo to just forward its files.
21# In particular, this allows filegroups to be used. This is mostly to ease
22# transition. In the future, the TdInfo provider will be required.
23# TODO(gcmn): Switch to enforcing TdInfo provider.
def _get_dep_transitive_srcs(dep):
    """Return the TableGen source files contributed by a dependency.

    Args:
      dep: a dependency target.
    Returns:
      TdInfo.transitive_sources when dep provides TdInfo, otherwise the files
      of its DefaultInfo.
    """
    if TdInfo not in dep:
        # Targets without TdInfo (e.g. filegroups) just forward their files.
        return dep[DefaultInfo].files
    return dep[TdInfo].transitive_sources
29
def _get_dep_transitive_includes(dep):
    """Return the include paths contributed by a dependency.

    Args:
      dep: a dependency target.
    Returns:
      TdInfo.transitive_includes when dep provides TdInfo, otherwise an empty
      depset.
    """
    return dep[TdInfo].transitive_includes if TdInfo in dep else depset()
35
def _get_transitive_srcs(srcs, deps):
    """Collect the sources of a target together with those of its deps.

    Args:
      srcs: a list of source files
      deps: a list of targets that are direct dependencies
    Returns:
      a depset with srcs as direct members and the sources of every dep as
      transitive members
    """
    srcs_of_deps = [_get_dep_transitive_srcs(dep) for dep in deps]
    return depset(direct = srcs, transitive = srcs_of_deps)
49
def _get_transitive_includes(includes, deps):
    """Collect the include paths of a target together with those of its deps.

    Args:
      includes: a list of include paths
      deps: a list of targets that are direct dependencies
    Returns:
      a depset with includes as direct members and the include paths of every
      dep as transitive members
    """
    includes_of_deps = [_get_dep_transitive_includes(dep) for dep in deps]
    return depset(direct = includes, transitive = includes_of_deps)
63
def _prefix_roots(ctx, includes):
    """Expand each include so it is also looked up under the output roots.

    Every include is emitted three times, in this order: as given (source
    root), joined onto the genfiles directory, and joined onto the bin
    directory.

    Args:
      ctx: the rule context, used for the genfiles and bin directory paths
      includes: a list of include paths
    Returns:
      a list holding the three variants of every include, in input order
    """
    output_roots = [ctx.genfiles_dir.path, ctx.bin_dir.path]
    expanded = []
    for include in includes:
        expanded.append(include)
        expanded.extend([paths.join(root, include) for root in output_roots])
    return expanded
77
def _resolve_includes(ctx, includes):
    """Turn user-specified include paths into execution-root-relative paths.

    A relative include is taken relative to the package of the current label.
    An absolute include is taken relative to the workspace root of the current
    label. Each resulting path is then expanded by _prefix_roots.

    Args:
      ctx: the rule context
      includes: a list of include paths as written on the target
    Returns:
      a list of resolved include paths
    """

    # An empty workspace root is replaced by "." before joining.
    root = ctx.label.workspace_root or "."
    resolved = []
    for include in includes:
        if paths.is_absolute(include):
            in_workspace = include.lstrip("/")
        else:
            in_workspace = paths.join(ctx.label.package, include)
        resolved.extend(_prefix_roots(ctx, [paths.join(root, in_workspace)]))
    return resolved
96
def _td_library_impl(ctx):
    """Implementation of td_library.

    Returns DefaultInfo (transitive sources as files, plus runfiles) and a
    TdInfo carrying the transitive sources and include paths.
    """
    own_srcs = ctx.files.srcs
    deps = ctx.attr.deps

    all_srcs = _get_transitive_srcs(own_srcs, deps)
    all_includes = _get_transitive_includes(
        _resolve_includes(ctx, ctx.attr.includes),
        deps,
    )

    # A td_library has no compiled output: it is only a set of source files
    # and include directories. Anything that needs it at execution time
    # (likely a rule running tblgen as a test action) therefore needs the
    # sources themselves, so they go into the runfiles.
    # Merged one at a time because merge_all is not available in Bazel 4.0.
    runfiles = ctx.runfiles(own_srcs)
    for target in ctx.attr.srcs + deps:
        runfiles = runfiles.merge(target[DefaultInfo].default_runfiles)

    td_info = TdInfo(
        transitive_sources = all_srcs,
        transitive_includes = all_includes,
    )
    return [
        DefaultInfo(files = all_srcs, runfiles = runfiles),
        td_info,
    ]
123
# Rule bundling TableGen source files and include directories so that other
# targets can depend on them through the TdInfo provider.
td_library = rule(
    implementation = _td_library_impl,
    attrs = {
        # Source files of this library; any file type is accepted.
        "srcs": attr.label_list(allow_files = True),
        "includes": attr.string_list(
            doc = "Include paths to be added to the final TableGen tool" +
                  " invocation. Relative paths are interpreted as relative to" +
                  " the current label's package. Absolute paths are" +
                  " interpreted as relative to the current label's workspace",
        ),
        # TODO(gcmn): limit to TdInfo providers.
        "deps": attr.label_list(
            doc = "Dependencies providing TableGen source files and include" +
                  " paths.",
        ),
    },
)
141
def _gentbl_rule_impl(ctx):
    """Implementation of gentbl_rule: registers one TableGen action.

    The action runs the `tblgen` executable on `td_file` with the configured
    options and include directories and writes to the declared `out` file.
    """
    td_file = ctx.file.td_file
    out_file = ctx.outputs.out
    deps = ctx.attr.deps

    all_srcs = _get_transitive_srcs(ctx.files.td_srcs + [td_file], deps)

    # td_file.dirname is already relative to the execution root, i.e. it may
    # contain an `external/<workspace_name>` prefix if the current workspace
    # is not the main workspace. It is therefore kept out of the
    # _resolve_includes call, which would prepend that prefix again.
    workspace_includes = _resolve_includes(ctx, ctx.attr.includes + ["/"])
    td_dir_includes = _prefix_roots(ctx, [td_file.dirname])
    all_includes = _get_transitive_includes(
        workspace_includes + td_dir_includes,
        deps,
    )

    cmdline = ctx.actions.args()
    cmdline.add_all(ctx.attr.opts)
    cmdline.add(td_file)
    cmdline.add_all(all_includes, before_each = "-I")
    cmdline.add("-o", out_file.path)

    ctx.actions.run(
        mnemonic = "TdGenerate",
        executable = ctx.executable.tblgen,
        arguments = [cmdline],
        inputs = all_srcs,
        outputs = [out_file],
        # Make sure action_env settings are honored so the env is the same as
        # when the tool was built. Important for locating shared libraries with
        # a custom LD_LIBRARY_PATH.
        use_default_shell_env = True,
    )

    return [DefaultInfo()]
180
# Rule running a TableGen executable (built for the exec configuration) on a
# single .td file to produce a single output file.
gentbl_rule = rule(
    implementation = _gentbl_rule_impl,
    doc = "Generates tabular code from a table definition file.",
    attrs = {
        "tblgen": attr.label(
            doc = "The TableGen executable with which to generate `out`.",
            executable = True,
            cfg = "exec",
        ),
        "td_file": attr.label(
            doc = "The TableGen file to run through `tblgen`.",
            allow_single_file = True,
            mandatory = True,
        ),
        "td_srcs": attr.label_list(
            doc = "Additional TableGen files included by `td_file`. It is not" +
                  " necessary to list td_file here (though not an error).",
            allow_files = True,
        ),
        # TODO(gcmn): limit to TdInfo providers.
        "deps": attr.label_list(
            doc = "Dependencies providing TableGen source files and include" +
                  " paths.",
        ),
        "out": attr.output(
            doc = "The output file for the TableGen invocation.",
            mandatory = True,
        ),
        "opts": attr.string_list(
            doc = "Additional command line options to add to the TableGen" +
                  " invocation. For include arguments, prefer to use" +
                  " `includes`.",
        ),
        # The implementation always appends "/" (the workspace root) and the
        # directory of td_file to these includes.
        "includes": attr.string_list(
            doc = "Include paths to be added to the final TableGen tool" +
                  " invocation. Relative paths are interpreted as relative to" +
                  " the current label's package. Absolute paths are" +
                  " interpreted as relative to the current label's workspace." +
                  " Includes are applied from all roots available in the" +
                  " execution environment (source, genfiles, and bin" +
                  " directories). The execution roots themselves and the" +
                  " directory of td_file are always added.",
        ),
    },
)
226
227# TODO(gcmn): Figure out how to reduce duplication with _gentbl_rule_impl
def _gentbl_test_impl(ctx):
    """Implementation of gentbl_test.

    Writes the TableGen command line (with output sent to /dev/null) into the
    test executable and returns the runfiles the command needs along with
    coverage instrumentation info.
    """
    td_file = ctx.file.td_file
    tblgen = ctx.executable.tblgen
    deps = ctx.attr.deps

    # td_file.dirname is already relative to the execution root, i.e. it may
    # contain an `external/<workspace_name>` prefix if the current workspace
    # is not the main workspace. It is therefore kept out of the
    # _resolve_includes call, which would prepend that prefix again.
    own_includes = (
        _resolve_includes(ctx, ctx.attr.includes + ["/"]) +
        _prefix_roots(ctx, [td_file.dirname])
    )
    include_dirs = _get_transitive_includes(own_includes, deps).to_list()

    # The pieces are joined with single spaces to form the script content.
    command = (
        [tblgen.short_path] +
        ctx.attr.opts +
        [td_file.path] +
        ["-I " + include for include in include_dirs] +
        ["-o", "/dev/null"]
    )
    ctx.actions.write(
        ctx.outputs.executable,
        content = " ".join(command),
        is_executable = True,
    )

    # Merged one at a time because merge_all is not available in Bazel 4.0.
    runfiles = ctx.runfiles(
        files = [tblgen],
        transitive_files = _get_transitive_srcs(
            ctx.files.td_srcs + [td_file],
            deps,
        ),
    )
    for target in ctx.attr.td_srcs + deps:
        runfiles = runfiles.merge(target[DefaultInfo].default_runfiles)

    coverage_info = coverage_common.instrumented_files_info(
        ctx,
        source_attributes = ["td_file", "td_srcs"],
        dependency_attributes = ["tblgen", "deps"],
    )
    return [
        coverage_info,
        DefaultInfo(runfiles = runfiles),
    ]
275
# Test rule whose executable is the TableGen command line itself; `tblgen` is
# built in the target configuration rather than the exec configuration.
gentbl_test = rule(
    implementation = _gentbl_test_impl,
    test = True,
    doc = "A shell test that tests the given TableGen invocation. Note" +
          " that unlike gentbl_rule, this builds and invokes `tblgen` in the" +
          " target configuration. Takes all the same arguments as gentbl_rule" +
          " except for `out` (as it does not generate any output).",
    attrs = {
        "tblgen": attr.label(
            doc = "The TableGen executable run in the shell command. Note" +
                  " that this is built in the target configuration.",
            executable = True,
            cfg = "target",
        ),
        "td_file": attr.label(
            doc = "See gentbl_rule.td_file",
            allow_single_file = True,
            mandatory = True,
        ),
        "td_srcs": attr.label_list(
            doc = "See gentbl_rule.td_srcs",
            allow_files = True,
        ),
        "deps": attr.label_list(doc = "See gentbl_rule.deps"),
        "opts": attr.string_list(doc = "See gentbl_rule.opts"),
        "includes": attr.string_list(doc = "See gentbl_rule.includes"),
    },
)
304
def gentbl_filegroup(
        name,
        tblgen,
        td_file,
        tbl_outs,
        td_srcs = [],
        includes = [],
        deps = [],
        test = False,
        skip_opts = [],
        **kwargs):
    """Create multiple TableGen generated files using the same tool and input.

    All generated outputs are bundled in a file group with the given name.

    Args:
      name: The name of the generated filegroup rule for use in dependencies.
      tblgen: The binary used to produce the output.
      td_file: The primary table definitions file.
      tbl_outs: A list of tuples ([opts], out), where each 'opts' is a list of
        options passed to tblgen, each option being a string, and 'out' is the
        corresponding output file produced.
      td_srcs: See gentbl_rule.td_srcs
      includes: See gentbl_rule.includes
      deps: See gentbl_rule.deps
      test: Whether to create a shell test that invokes the tool too.
      skip_opts: Files generated using these opts in tbl_outs will be excluded
        from the generated filegroup.
      **kwargs: Extra keyword arguments to pass to all generated rules. Any
        `tags` given here are kept on the generated tests, which additionally
        receive the "no_windows" tag.
    """

    # The generated tests always carry the "no_windows" tag. Merge it with any
    # caller-provided tags instead of passing `tags` twice, which would fail
    # with a duplicate keyword argument error.
    test_kwargs = dict(kwargs)
    test_kwargs["tags"] = list(kwargs.get("tags") or []) + ["no_windows"]

    for (opts, out) in tbl_outs:
        # The rule name is derived from the first option plus a hash of all
        # options, so that every entry of tbl_outs gets its own target name.
        first_opt = opts[0] if opts else ""
        rule_suffix = "_{}_{}".format(
            first_opt.replace("-", "_").replace("=", "_"),
            str(hash(" ".join(opts))),
        )
        gentbl_name = "%s_%s_genrule" % (name, rule_suffix)
        gentbl_rule(
            name = gentbl_name,
            td_file = td_file,
            tblgen = tblgen,
            opts = opts,
            td_srcs = td_srcs,
            deps = deps,
            includes = includes,
            out = out,
            **kwargs
        )

        if test:
            # Also run the generator in the target configuration as a test. This
            # means it gets run with asserts and sanitizers and such when they
            # are enabled and is counted in coverage.
            # Shell files are not executable on Windows, hence the
            # "no_windows" tag in test_kwargs.
            # TODO(gcmn): Support windows.
            gentbl_test(
                name = "%s_test" % (gentbl_name,),
                td_file = td_file,
                tblgen = tblgen,
                opts = opts,
                td_srcs = td_srcs,
                deps = deps,
                includes = includes,
                **test_kwargs
            )

    # Drop every output whose option list contains one of skip_opts.
    included_srcs = [f for (opts, f) in tbl_outs if not any([skip_opt in opts for skip_opt in skip_opts])]
    native.filegroup(
        name = name,
        srcs = included_srcs,
        **kwargs
    )
379
def gentbl_cc_library(
        name,
        tblgen,
        td_file,
        tbl_outs,
        td_srcs = [],
        includes = [],
        deps = [],
        strip_include_prefix = None,
        test = False,
        copts = None,
        **kwargs):
    """Create multiple TableGen generated files using the same tool and input.

    The outputs are first collected in a filegroup named `<name>_filegroup`
    (outputs generated with `-gen-op-doc` are left out of it), which is then
    wrapped in a cc_library rule.

    Args:
      name: The name of the generated cc_library rule for use in dependencies.
      tblgen: The binary used to produce the output.
      td_file: The primary table definitions file.
      tbl_outs: A list of tuples ([opts], out), where each 'opts' is a list of
        options passed to tblgen, each option being a string, and 'out' is the
        corresponding output file produced.
      td_srcs: See gentbl_rule.td_srcs
      includes: See gentbl_rule.includes
      deps: See gentbl_rule.deps
      strip_include_prefix: attribute to pass through to cc_library.
      test: whether to create a shell test that invokes the tool too.
      copts: list of copts to pass to cc_library.
      **kwargs: Extra keyword arguments to pass to all generated rules.
    """
    filegroup_name = name + "_filegroup"
    filegroup_label = ":" + filegroup_name

    gentbl_filegroup(
        name = filegroup_name,
        tblgen = tblgen,
        td_file = td_file,
        tbl_outs = tbl_outs,
        td_srcs = td_srcs,
        includes = includes,
        deps = deps,
        test = test,
        skip_opts = ["-gen-op-doc"],
        **kwargs
    )

    # strip_include_prefix does not apply to textual_hdrs, so the generated
    # files are also listed as hdrs whenever a prefix is requested.
    # https://github.com/bazelbuild/bazel/issues/12424
    hdrs = [filegroup_label] if strip_include_prefix else []

    native.cc_library(
        name = name,
        hdrs = hdrs,
        strip_include_prefix = strip_include_prefix,
        textual_hdrs = [filegroup_label],
        copts = copts,
        **kwargs
    )
435
def _gentbl_shard_impl(ctx):
    """Implementation of gentbl_shard_rule: registers one sharder action.

    The `sharder` executable is invoked as
    `<sharder> <src_file> -op-shard-index <index> -o <out>`.
    """
    src = ctx.file.src_file
    out_file = ctx.outputs.out

    cmdline = ctx.actions.args()
    cmdline.add(src)
    cmdline.add("-op-shard-index", ctx.attr.index)
    cmdline.add("-o", out_file.path)

    ctx.actions.run(
        mnemonic = "ShardGenerate",
        executable = ctx.executable.sharder,
        arguments = [cmdline],
        inputs = [src],
        outputs = [out_file],
        use_default_shell_env = True,
    )
449
# Rule running the `sharder` tool on a source file for one shard index. The
# output is placed in the genfiles directory (output_to_genfiles).
gentbl_shard_rule = rule(
    implementation = _gentbl_shard_impl,
    doc = "Runs `sharder` on `src_file` with `-op-shard-index` set to" +
          " `index` and writes the result to `out`.",
    output_to_genfiles = True,
    attrs = {
        "index": attr.int(
            mandatory = True,
            doc = "The value passed to `sharder` as `-op-shard-index`.",
        ),
        "sharder": attr.label(
            doc = "The executable that is run to generate `out`.",
            executable = True,
            cfg = "exec",
        ),
        "src_file": attr.label(
            doc = "The single input file passed to `sharder`.",
            allow_single_file = True,
            mandatory = True,
        ),
        "out": attr.output(
            doc = "The output file passed to `sharder` with `-o`.",
            mandatory = True,
        ),
    },
)
472
def gentbl_sharded_ops(
        name,
        tblgen,
        sharder,
        td_file,
        shard_count,
        src_file,
        src_out,
        hdr_out,
        test = False,
        includes = [],
        strip_include_prefix = None,
        deps = []):
    """Generate sharded op declarations and definitions.

    This special build rule shards op definitions in a TableGen file and
    generates multiple copies of a template source file for including and
    compiling each shard. The rule defines a filegroup consisting of the
    generated header file, the generated source file, and the source shards.

    Args:
      name: The name of the filegroup.
      tblgen: The binary used to produce the output.
      sharder: The source file sharder to use.
      td_file: The primary table definitions file.
      shard_count: The number of op definition shards to produce.
      src_file: The source file template.
      src_out: The generated source file.
      hdr_out: The generated header file.
      test: Forwarded as `test` to gentbl_cc_library and used as `testonly`
        for the generated shard rules.
      includes: See gentbl_rule.includes
      strip_include_prefix: Attribute to pass through to cc_library.
      deps: See gentbl_rule.deps
    """

    # Both TableGen invocations need the same shard count option.
    shard_count_opt = "-op-shard-count=" + str(shard_count)
    gentbl_cc_library(
        name = name + "__gentbl_cc_lib",
        tblgen = tblgen,
        td_file = td_file,
        tbl_outs = [
            (["-gen-op-defs", shard_count_opt], src_out),
            (["-gen-op-decls", shard_count_opt], hdr_out),
        ],
        includes = includes,
        deps = deps,
        strip_include_prefix = strip_include_prefix,
        test = test,
    )

    # One sharder invocation per shard, each producing its own copy of the
    # source file template.
    shard_outs = []
    for shard in range(shard_count):
        shard_id = str(shard)
        shard_out = "shard_copy_" + shard_id + "_" + src_file
        gentbl_shard_rule(
            name = name + "__src_shard" + shard_id,
            index = shard,
            sharder = sharder,
            src_file = src_file,
            out = shard_out,
            testonly = test,
        )
        shard_outs.append(shard_out)

    native.filegroup(name = name, srcs = [hdr_out, src_out] + shard_outs)
545
def gentbl_sharded_op_defs(name, source_file, shard_count):
    """Generates multiple copies of a source file that includes sharded op definitions.

    Each copy is the original source preceded by one line of the form
    `#define GET_OP_DEFS_<i>`, where `<i>` is the shard index.

    Args:
      name: The name of the rule.
      source_file: The source to copy.
      shard_count: The number of shards.

    Returns:
      A list of the copied filenames to be included in the dialect library.
    """
    copies = []
    for i in range(0, shard_count):
        out_file = "shard_copy_" + str(i) + "_" + source_file
        copies.append(out_file)
        native.genrule(
            name = name + "_shard_" + str(i),
            srcs = [source_file],
            outs = [out_file],
            # Emit the define and then copy the source verbatim with cat.
            # Passing the source through `echo -e` would interpret backslash
            # escapes (e.g. "\n" inside string literals) and alter the code.
            cmd = "(echo '#define GET_OP_DEFS_" + str(i) + "'; cat $(SRCS)) > $(OUTS)",
        )
    return copies
568