//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
// automatically. It is used by NVVM to LLVM pass.
//
//===----------------------------------------------------------------------===//

#ifndef BASICPTXBUILDER_OP_INTERFACE
#define BASICPTXBUILDER_OP_INTERFACE

include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"

//===----------------------------------------------------------------------===//
// Basic PTX Builder Interface
//===----------------------------------------------------------------------===//

// Optional i1 operand constraint. Ops use it to declare the predicate operand
// described by `getPredicate` below.
def PtxPredicate : Optional<I1>;

// Op interface implemented by NVVM ops that are lowered to PTX through inline
// assembly. An implementing op must provide `getPtx`; every other method has a
// default implementation that the op may override.
def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
  let description = [{
    This interface is used to generate inline assembly with PTX for basic
    operations. It's utilized in the `convert-nvvm-to-llvm` pass to lower
    NVVM Ops that implement this interface to PTX (parallel thread execution)
    using inline assembly Ops. Interface methods play a crucial role in this
    lowering process.

    Here's an example of an Op with the `BasicPtxBuilderOpInterface`:
    ```tablegen
      def NVVM_SpecialOp : NVVM_Op<"special.op",
          [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
        Results<(outs LLVM_Type:$res)>,
        Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
        ...
        let extraClassDefinition = [{
          std::string $cppClass::getPtx() {
            return std::string("special.op %0, %1, %2;");
          }
        } ];
      }
    ```

    In the above NVVM Op example:
    ```mlir
      %0 = nvvm.special.op %1, %2 : !llvm.ptr, i32 -> i32
    ```

    The `convert-nvvm-to-llvm` pass generates the inline assembly like below.
    The order of arguments is retained, and the read and write modifiers are
    set based on the input and result types:
    ```mlir
      %0 = llvm.inline_asm
                has_side_effects
                asm_dialect =
                att "special.op %0, %1, %2;", "=r,l,r" %arg0, %arg1
                : (!llvm.ptr, i32) -> i32
    ```
  }];
  let cppNamespace = "::mlir::NVVM";
  let methods = [
    // Predicate hook: the default reports "no predicate" by returning an empty
    // optional.
    InterfaceMethod<
        /*desc=*/[{
          Optional function for setting a predicate, which
          always returns a `PtxPredicate` value of type i1. If no predicate is
          provided, the instruction is unguarded; otherwise, it's guarded by the
          predicate value. The `PtxPredicate` value must always be the last argument.
          The provided PTX code by `getPtx` should not include the predicate usage.
          The interface automatically handles predicate usage in the generated
          PTX code when necessary.
        }],
        /*retType=*/"std::optional<::mlir::Value>",
        /*methodName=*/"getPredicate",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return {};"
      >,
    // The only method without a default implementation: each op supplies its
    // own PTX string.
    InterfaceMethod<
        /*desc=*/[{ Returns PTX assembly with operand number. }],
        /*retType=*/"std::string",
        /*methodName=*/"getPtx"
      >,
    // Defaults to false, i.e. the op is not reported as intrinsic-backed.
    InterfaceMethod<
        /*desc=*/[{
          This function indicates whether the operation is supported by LLVM
          intrinsics. It's particularly useful for operations that have
          specific cases with LLVM intrinsic support.
        }],
        /*retType=*/"bool",
        /*methodName=*/"hasIntrinsic",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return false;"
      >,
    // Defaults to true, the conservative answer.
    InterfaceMethod<
        /*desc=*/[{Return whether the operation has memory side effects.}],
        /*retType=*/"bool",
        /*methodName=*/"hasSideEffect",
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/"return true;"
      >,

    // Builds an `LLVM::ConstantOp` of 32-bit integer type holding `val`, at
    // the location of the op implementing the interface.
    InterfaceMethod<
        /*desc=*/[{Helper function to generate i32 constant value.}],
        /*retType=*/"::mlir::Value",
        /*methodName=*/"makeConstantI32",
        /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
        /*methodBody=*/"",
        /*defaultImplementation=*/ [{
            mlir::Operation* op = $_op;
            return rewriter.create<LLVM::ConstantOp>(
              op->getLoc(), rewriter.getIntegerType(32), val);
        }]
      >,
    // Collects the values handed to the inline assembly, each tagged with a
    // register modifier, by appending to `asmValues` in the documented order.
    InterfaceMethod<
        /*desc=*/[{
          This function supplies the necessary arguments for passing PTX code,
          following this order:
           1) Adds results
           2) Adds operands
           3) Adds attributes
        }],
        /*retType=*/"void",
        /*methodName=*/"getAsmValues",
        /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
          "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
        /*methodBody=*/"",
        /*defaultImplementation=*/ [{
          mlir::Operation* op = $_op;

          // Step 1. Add results (tagged as written by the assembly).
          for (auto val : op->getResults())
            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});

          // Step 2. Add operands (tagged as read by the assembly).
          for (auto val : op->getOperands())
            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});

          // Step 3. Add attributes. Only integer attributes are forwarded;
          // each one is materialized as an i32 constant and tagged as read.
          // Attributes of any other kind are skipped.
          // NOTE(review): makeConstantI32 takes `int`, so an attribute value
          // that does not fit in 32 bits would presumably be narrowed here;
          // confirm no implementing op carries wider integer attributes.
          for (auto attr : op->getAttrs()) {
            if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
              ::mlir::Value val = makeConstantI32(rewriter, intAttr.getInt());
              asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
            }
          }
        }]
      >
  ];
}

#endif // BASICPTXBUILDER_OP_INTERFACE