xref: /llvm-project/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp (revision 19992eea2300f1e97a71013b50f582538cad5022)
//===- BasicPtxBuilderInterface.cpp - PTX builder interface ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
10 // automatically. It is used by NVVM to LLVM pass.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
15 
16 #define DEBUG_TYPE "ptx-builder"
17 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
18 #define DBGSNL() (llvm::dbgs() << "\n")
19 
20 //===----------------------------------------------------------------------===//
21 // BasicPtxBuilderInterface
22 //===----------------------------------------------------------------------===//
23 
24 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc"
25 
26 using namespace mlir;
27 using namespace NVVM;
28 
29 static constexpr int64_t kSharedMemorySpace = 3;
30 
31 static char getRegisterType(Type type) {
32   if (type.isInteger(1))
33     return 'b';
34   if (type.isInteger(16))
35     return 'h';
36   if (type.isInteger(32))
37     return 'r';
38   if (type.isInteger(64))
39     return 'l';
40   if (type.isF32())
41     return 'f';
42   if (type.isF64())
43     return 'd';
44   if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
45     // Shared address spaces is addressed with 32-bit pointers.
46     if (ptr.getAddressSpace() == kSharedMemorySpace) {
47       return 'r';
48     }
49     return 'l';
50   }
51   // register type for struct is not supported.
52   llvm_unreachable("The register type could not deduced from MLIR type");
53   return '?';
54 }
55 
56 static char getRegisterType(Value v) {
57   if (v.getDefiningOp<LLVM::ConstantOp>())
58     return 'n';
59   return getRegisterType(v.getType());
60 }
61 
62 void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
63   LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
64   auto getModifier = [&]() -> const char * {
65     if (itype == PTXRegisterMod::ReadWrite) {
66       assert(false && "Read-Write modifier is not supported. Try setting the "
67                       "same value as Write and Read separately.");
68       return "+";
69     }
70     if (itype == PTXRegisterMod::Write) {
71       return "=";
72     }
73     return "";
74   };
75   auto addValue = [&](Value v) {
76     if (itype == PTXRegisterMod::Read) {
77       ptxOperands.push_back(v);
78       return;
79     }
80     if (itype == PTXRegisterMod::ReadWrite)
81       ptxOperands.push_back(v);
82     hasResult = true;
83   };
84 
85   llvm::raw_string_ostream ss(registerConstraints);
86   // Handle Structs
87   if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) {
88     if (itype == PTXRegisterMod::Write) {
89       addValue(v);
90     }
91     for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
92       if (itype != PTXRegisterMod::Write) {
93         Value extractValue = rewriter.create<LLVM::ExtractValueOp>(
94             interfaceOp->getLoc(), v, idx);
95         addValue(extractValue);
96       }
97       if (itype == PTXRegisterMod::ReadWrite) {
98         ss << idx << ",";
99       } else {
100         ss << getModifier() << getRegisterType(t) << ",";
101       }
102     }
103     return;
104   }
105   // Handle Scalars
106   addValue(v);
107   ss << getModifier() << getRegisterType(v) << ",";
108 }
109 
110 LLVM::InlineAsmOp PtxBuilder::build() {
111   auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
112                                                   LLVM::AsmDialect::AD_ATT);
113 
114   auto resultTypes = interfaceOp->getResultTypes();
115 
116   // Remove the last comma from the constraints string.
117   if (!registerConstraints.empty() &&
118       registerConstraints[registerConstraints.size() - 1] == ',')
119     registerConstraints.pop_back();
120 
121   std::string ptxInstruction = interfaceOp.getPtx();
122 
123   // Add the predicate to the asm string.
124   if (interfaceOp.getPredicate().has_value() &&
125       interfaceOp.getPredicate().value()) {
126     std::string predicateStr = "@%";
127     predicateStr += std::to_string((ptxOperands.size() - 1));
128     ptxInstruction = predicateStr + " " + ptxInstruction;
129   }
130 
131   // Tablegen doesn't accept $, so we use %, but inline assembly uses $.
132   // Replace all % with $
133   std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$');
134 
135   return rewriter.create<LLVM::InlineAsmOp>(
136       interfaceOp->getLoc(),
137       /*result types=*/resultTypes,
138       /*operands=*/ptxOperands,
139       /*asm_string=*/llvm::StringRef(ptxInstruction),
140       /*constraints=*/registerConstraints.data(),
141       /*has_side_effects=*/interfaceOp.hasSideEffect(),
142       /*is_align_stack=*/false,
143       /*asm_dialect=*/asmDialectAttr,
144       /*operand_attrs=*/ArrayAttr());
145 }
146 
147 void PtxBuilder::buildAndReplaceOp() {
148   LLVM::InlineAsmOp inlineAsmOp = build();
149   LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
150   if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
151     rewriter.replaceOp(interfaceOp, inlineAsmOp);
152   } else {
153     rewriter.eraseOp(interfaceOp);
154   }
155 }
156