xref: /llvm-project/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td (revision 63389326f529fd3e3019f8f8afae662e765a3b72)
//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
// automatically. It is used by the NVVM-to-LLVM conversion pass
// (`convert-nvvm-to-llvm`).
//
//===----------------------------------------------------------------------===//
13
14#ifndef BASICPTXBUILDER_OP_INTERFACE
15#define BASICPTXBUILDER_OP_INTERFACE
16
17include "mlir/IR/EnumAttr.td"
18include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
19include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
20
//===----------------------------------------------------------------------===//
// Basic PTX Builder Interface
//===----------------------------------------------------------------------===//
24
// Optional i1 operand that ops implementing `BasicPtxBuilderOpInterface` use
// as the guard predicate of the generated PTX instruction. When present, it
// must be the op's last operand.
def PtxPredicate : Optional<I1> {}
26
def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
  // NOTE: the TableGen example inside the description writes `} ];` (with a
  // space) on purpose: a literal `}]` would terminate the `[{ ... }]` code
  // string of the description early.
  let description = [{
    This interface is used to generate inline assembly with PTX for basic
    operations. It's utilized in the `convert-nvvm-to-llvm` pass to lower
    NVVM Ops that implement this interface to PTX (Parallel Thread Execution)
    using inline assembly Ops. Interface methods play a crucial role in this
    lowering process.

    Here's an example of an Op with the `BasicPtxBuilderOpInterface`:
    ```tablegen
      def NVVM_SpecialOp : NVVM_Op<"special.op",
          [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
        Results<(outs LLVM_Type:$res)>,
        Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
        ...
        let extraClassDefinition = [{
          std::string $cppClass::getPtx() {
            return std::string("special.op %0, %1, %2;");
          }
        } ];
      }
    ```

    In the above NVVM Op example:
    ```mlir
      %0 = nvvm.special.op %1, %2 : !llvm.ptr, i32 -> i32
    ```

    The `convert-nvvm-to-llvm` pass generates inline assembly like the one
    below. The order of arguments is retained. Results are passed as written
    (`=`) values and operands as read values:
    ```mlir
      %0 = llvm.inline_asm has_side_effects asm_dialect = att
                "special.op %0, %1, %2;", "=r,l,r" %1, %2
                : (!llvm.ptr, i32) -> i32
    ```
  }];
  let cppNamespace = "::mlir::NVVM";
  let methods = [
    InterfaceMethod<
        /*desc=*/[{
          Returns the optional predicate that guards the instruction. A
          returned value must be the i1 `PtxPredicate` operand, which must
          always be the op's last operand. When no value is returned, the
          instruction is unguarded. The PTX returned by `getPtx` must not
          reference the predicate: the interface inserts the predicate guard
          into the generated PTX when necessary. The default implementation
          returns no predicate.
        }],
        /*retType=*/"std::optional<::mlir::Value>",
        /*methodName=*/"getPredicate",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return {};"
      >,
    InterfaceMethod<
        /*desc=*/[{
          Returns the PTX assembly for this op. Operands are referenced by
          number (`%0`, `%1`, ...), as in `"special.op %0, %1, %2;"`.
        }],
        /*retType=*/"std::string",
        /*methodName=*/"getPtx"
      >,
    InterfaceMethod<
        /*desc=*/[{
          This function indicates whether the operation is supported by LLVM
          intrinsics. It's particularly useful for operations that have
          specific cases with LLVM intrinsic support. Defaults to `false`.
        }],
        /*retType=*/"bool",
        /*methodName=*/"hasIntrinsic",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return false;"
      >,
    InterfaceMethod<
        /*desc=*/[{
          Returns whether the operation has memory side effects. Defaults to
          `true`.
        }],
        /*retType=*/"bool",
        /*methodName=*/"hasSideEffect",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return true;"
      >,

    InterfaceMethod<
        /*desc=*/[{
          Helper that creates an `llvm.mlir.constant` of type i32 holding
          `val`, located at this op.
        }],
        /*retType=*/"::mlir::Value",
        /*methodName=*/"makeConstantI32",
        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
        /*methodBody=*/"",
        /*defaultImplementation=*/ [{
          mlir::Operation* op = $_op;
          return rewriter.create<LLVM::ConstantOp>(
            op->getLoc(), rewriter.getIntegerType(32), val);
        }]
      >,
    InterfaceMethod<
        /*desc=*/[{
          Collects the values passed to the generated inline assembly, each
          paired with its register modifier, in this order:
            1) results, marked `Write`;
            2) operands, marked `Read` (this includes the predicate operand,
               if any);
            3) integer attributes, each materialized as an i32 constant via
               `makeConstantI32` and marked `Read`.
          Integer attributes of any width or signedness are sign-extended or
          truncated to 32 bits. Non-integer attributes are skipped. Attributes
          are visited in the order returned by `Operation::getAttrs()`.
        }],
        /*retType=*/"void",
        /*methodName=*/"getAsmValues",
        /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
          "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
        /*methodBody=*/"",
        /*defaultImplementation=*/ [{
          mlir::Operation* op = $_op;

          // Step 1. Add results (written by the PTX instruction).
          for (::mlir::Value val : op->getResults())
            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});

          // Step 2. Add operands (read by the PTX instruction).
          for (::mlir::Value val : op->getOperands())
            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});

          // Step 3. Add integer attributes as i32 constants.
          for (::mlir::NamedAttribute attr : op->getAttrs()) {
            auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue());
            if (!intAttr)
              continue;
            // `IntegerAttr::getInt()` asserts on signed/unsigned integer types
            // and on values wider than 64 bits, so read the raw APInt instead.
            // Resizing it to 32 bits yields exactly the i32 bit pattern that
            // was produced for signless attributes before, and makes the
            // narrowing to `int` explicit and value-preserving.
            int bits = static_cast<int>(
                intAttr.getValue().sextOrTrunc(32).getSExtValue());
            ::mlir::Value val = makeConstantI32(rewriter, bits);
            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
          }
        }]
      >
  ];
}
156
157#endif // BASICPTXBUILDER_OP_INTERFACE