1 //===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops 10 // automatically. It is used by NVVM to LLVM pass. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h" 15 16 #define DEBUG_TYPE "ptx-builder" 17 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 18 #define DBGSNL() (llvm::dbgs() << "\n") 19 20 //===----------------------------------------------------------------------===// 21 // BasicPtxBuilderInterface 22 //===----------------------------------------------------------------------===// 23 24 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.cpp.inc" 25 26 using namespace mlir; 27 using namespace NVVM; 28 29 static constexpr int64_t kSharedMemorySpace = 3; 30 31 static char getRegisterType(Type type) { 32 if (type.isInteger(1)) 33 return 'b'; 34 if (type.isInteger(16)) 35 return 'h'; 36 if (type.isInteger(32)) 37 return 'r'; 38 if (type.isInteger(64)) 39 return 'l'; 40 if (type.isF32()) 41 return 'f'; 42 if (type.isF64()) 43 return 'd'; 44 if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) { 45 // Shared address spaces is addressed with 32-bit pointers. 46 if (ptr.getAddressSpace() == kSharedMemorySpace) { 47 return 'r'; 48 } 49 return 'l'; 50 } 51 // register type for struct is not supported. 52 llvm_unreachable("The register type could not deduced from MLIR type"); 53 return '?'; 54 } 55 56 static char getRegisterType(Value v) { 57 if (v.getDefiningOp<LLVM::ConstantOp>()) 58 return 'n'; 59 return getRegisterType(v.getType()); 60 } 61 62 void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { 63 LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n"); 64 auto getModifier = [&]() -> const char * { 65 if (itype == PTXRegisterMod::ReadWrite) { 66 assert(false && "Read-Write modifier is not supported. Try setting the " 67 "same value as Write and Read separately."); 68 return "+"; 69 } 70 if (itype == PTXRegisterMod::Write) { 71 return "="; 72 } 73 return ""; 74 }; 75 auto addValue = [&](Value v) { 76 if (itype == PTXRegisterMod::Read) { 77 ptxOperands.push_back(v); 78 return; 79 } 80 if (itype == PTXRegisterMod::ReadWrite) 81 ptxOperands.push_back(v); 82 hasResult = true; 83 }; 84 85 llvm::raw_string_ostream ss(registerConstraints); 86 // Handle Structs 87 if (auto stype = dyn_cast<LLVM::LLVMStructType>(v.getType())) { 88 if (itype == PTXRegisterMod::Write) { 89 addValue(v); 90 } 91 for (auto [idx, t] : llvm::enumerate(stype.getBody())) { 92 if (itype != PTXRegisterMod::Write) { 93 Value extractValue = rewriter.create<LLVM::ExtractValueOp>( 94 interfaceOp->getLoc(), v, idx); 95 addValue(extractValue); 96 } 97 if (itype == PTXRegisterMod::ReadWrite) { 98 ss << idx << ","; 99 } else { 100 ss << getModifier() << getRegisterType(t) << ","; 101 } 102 } 103 return; 104 } 105 // Handle Scalars 106 addValue(v); 107 ss << getModifier() << getRegisterType(v) << ","; 108 } 109 110 LLVM::InlineAsmOp PtxBuilder::build() { 111 auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), 112 LLVM::AsmDialect::AD_ATT); 113 114 auto resultTypes = interfaceOp->getResultTypes(); 115 116 // Remove the last comma from the constraints string. 117 if (!registerConstraints.empty() && 118 registerConstraints[registerConstraints.size() - 1] == ',') 119 registerConstraints.pop_back(); 120 121 std::string ptxInstruction = interfaceOp.getPtx(); 122 123 // Add the predicate to the asm string. 124 if (interfaceOp.getPredicate().has_value() && 125 interfaceOp.getPredicate().value()) { 126 std::string predicateStr = "@%"; 127 predicateStr += std::to_string((ptxOperands.size() - 1)); 128 ptxInstruction = predicateStr + " " + ptxInstruction; 129 } 130 131 // Tablegen doesn't accept $, so we use %, but inline assembly uses $. 132 // Replace all % with $ 133 std::replace(ptxInstruction.begin(), ptxInstruction.end(), '%', '$'); 134 135 return rewriter.create<LLVM::InlineAsmOp>( 136 interfaceOp->getLoc(), 137 /*result types=*/resultTypes, 138 /*operands=*/ptxOperands, 139 /*asm_string=*/llvm::StringRef(ptxInstruction), 140 /*constraints=*/registerConstraints.data(), 141 /*has_side_effects=*/interfaceOp.hasSideEffect(), 142 /*is_align_stack=*/false, 143 /*asm_dialect=*/asmDialectAttr, 144 /*operand_attrs=*/ArrayAttr()); 145 } 146 147 void PtxBuilder::buildAndReplaceOp() { 148 LLVM::InlineAsmOp inlineAsmOp = build(); 149 LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n"); 150 if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) { 151 rewriter.replaceOp(interfaceOp, inlineAsmOp); 152 } else { 153 rewriter.eraseOp(interfaceOp); 154 } 155 } 156