1 //===-- CUFGPUToLLVMConversion.cpp ----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Transforms/CUFGPUToLLVMConversion.h" 10 #include "flang/Common/Fortran.h" 11 #include "flang/Optimizer/CodeGen/TypeConverter.h" 12 #include "flang/Optimizer/Support/DataLayout.h" 13 #include "flang/Runtime/CUDA/common.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Transforms/DialectConversion.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 #include "llvm/Support/FormatVariadic.h" 20 21 namespace fir { 22 #define GEN_PASS_DEF_CUFGPUTOLLVMCONVERSION 23 #include "flang/Optimizer/Transforms/Passes.h.inc" 24 } // namespace fir 25 26 using namespace fir; 27 using namespace mlir; 28 using namespace Fortran::runtime; 29 30 namespace { 31 32 static mlir::Value createKernelArgArray(mlir::Location loc, 33 mlir::ValueRange operands, 34 mlir::PatternRewriter &rewriter) { 35 36 auto *ctx = rewriter.getContext(); 37 llvm::SmallVector<mlir::Type> structTypes(operands.size(), nullptr); 38 39 for (auto [i, arg] : llvm::enumerate(operands)) 40 structTypes[i] = arg.getType(); 41 42 auto structTy = mlir::LLVM::LLVMStructType::getLiteral(ctx, structTypes); 43 auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); 44 mlir::Type i32Ty = rewriter.getI32Type(); 45 auto zero = rewriter.create<mlir::LLVM::ConstantOp>( 46 loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0)); 47 auto one = rewriter.create<mlir::LLVM::ConstantOp>( 48 loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 1)); 49 mlir::Value argStruct = 50 rewriter.create<mlir::LLVM::AllocaOp>(loc, ptrTy, structTy, one); 51 auto size = rewriter.create<mlir::LLVM::ConstantOp>( 52 loc, i32Ty, rewriter.getIntegerAttr(i32Ty, structTypes.size())); 53 mlir::Value argArray = 54 rewriter.create<mlir::LLVM::AllocaOp>(loc, ptrTy, ptrTy, size); 55 56 for (auto [i, arg] : llvm::enumerate(operands)) { 57 auto indice = rewriter.create<mlir::LLVM::ConstantOp>( 58 loc, i32Ty, rewriter.getIntegerAttr(i32Ty, i)); 59 mlir::Value structMember = rewriter.create<LLVM::GEPOp>( 60 loc, ptrTy, structTy, argStruct, 61 mlir::ArrayRef<mlir::Value>({zero, indice})); 62 rewriter.create<LLVM::StoreOp>(loc, arg, structMember); 63 mlir::Value arrayMember = rewriter.create<LLVM::GEPOp>( 64 loc, ptrTy, ptrTy, argArray, mlir::ArrayRef<mlir::Value>({indice})); 65 rewriter.create<LLVM::StoreOp>(loc, structMember, arrayMember); 66 } 67 return argArray; 68 } 69 70 struct GPULaunchKernelConversion 71 : public mlir::ConvertOpToLLVMPattern<mlir::gpu::LaunchFuncOp> { 72 explicit GPULaunchKernelConversion( 73 const fir::LLVMTypeConverter &typeConverter, mlir::PatternBenefit benefit) 74 : mlir::ConvertOpToLLVMPattern<mlir::gpu::LaunchFuncOp>(typeConverter, 75 benefit) {} 76 77 using OpAdaptor = typename mlir::gpu::LaunchFuncOp::Adaptor; 78 79 mlir::LogicalResult 80 matchAndRewrite(mlir::gpu::LaunchFuncOp op, OpAdaptor adaptor, 81 mlir::ConversionPatternRewriter &rewriter) const override { 82 mlir::Location loc = op.getLoc(); 83 auto *ctx = rewriter.getContext(); 84 mlir::ModuleOp mod = op->getParentOfType<mlir::ModuleOp>(); 85 mlir::Value dynamicMemorySize = op.getDynamicSharedMemorySize(); 86 mlir::Type i32Ty = rewriter.getI32Type(); 87 if (!dynamicMemorySize) 88 dynamicMemorySize = rewriter.create<mlir::LLVM::ConstantOp>( 89 loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0)); 90 91 mlir::Value kernelArgs = 92 createKernelArgArray(loc, adaptor.getKernelOperands(), rewriter); 93 94 auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); 95 auto kernel = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(op.getKernelName()); 96 mlir::Value kernelPtr; 97 if (!kernel) { 98 auto funcOp = mod.lookupSymbol<mlir::func::FuncOp>(op.getKernelName()); 99 if (!funcOp) 100 return mlir::failure(); 101 kernelPtr = 102 rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, funcOp.getName()); 103 } else { 104 kernelPtr = 105 rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, kernel.getName()); 106 } 107 108 auto llvmIntPtrType = mlir::IntegerType::get( 109 ctx, this->getTypeConverter()->getPointerBitwidth(0)); 110 auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx); 111 112 mlir::Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrTy); 113 114 if (op.hasClusterSize()) { 115 auto funcOp = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>( 116 RTNAME_STRING(CUFLaunchClusterKernel)); 117 auto funcTy = mlir::LLVM::LLVMFunctionType::get( 118 voidTy, 119 {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, 120 llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, 121 llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy}, 122 /*isVarArg=*/false); 123 auto cufLaunchClusterKernel = mlir::SymbolRefAttr::get( 124 mod.getContext(), RTNAME_STRING(CUFLaunchClusterKernel)); 125 if (!funcOp) { 126 mlir::OpBuilder::InsertionGuard insertGuard(rewriter); 127 rewriter.setInsertionPointToStart(mod.getBody()); 128 auto launchKernelFuncOp = rewriter.create<mlir::LLVM::LLVMFuncOp>( 129 loc, RTNAME_STRING(CUFLaunchClusterKernel), funcTy); 130 launchKernelFuncOp.setVisibility( 131 mlir::SymbolTable::Visibility::Private); 132 } 133 rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( 134 op, funcTy, cufLaunchClusterKernel, 135 mlir::ValueRange{kernelPtr, adaptor.getClusterSizeX(), 136 adaptor.getClusterSizeY(), adaptor.getClusterSizeZ(), 137 adaptor.getGridSizeX(), adaptor.getGridSizeY(), 138 adaptor.getGridSizeZ(), adaptor.getBlockSizeX(), 139 adaptor.getBlockSizeY(), adaptor.getBlockSizeZ(), 140 dynamicMemorySize, kernelArgs, nullPtr}); 141 } else { 142 auto procAttr = 143 op->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName()); 144 bool isGridGlobal = 145 procAttr && procAttr.getValue() == cuf::ProcAttribute::GridGlobal; 146 llvm::StringRef fctName = isGridGlobal 147 ? RTNAME_STRING(CUFLaunchCooperativeKernel) 148 : RTNAME_STRING(CUFLaunchKernel); 149 auto funcOp = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(fctName); 150 auto funcTy = mlir::LLVM::LLVMFunctionType::get( 151 voidTy, 152 {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, 153 llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy}, 154 /*isVarArg=*/false); 155 auto cufLaunchKernel = 156 mlir::SymbolRefAttr::get(mod.getContext(), fctName); 157 if (!funcOp) { 158 mlir::OpBuilder::InsertionGuard insertGuard(rewriter); 159 rewriter.setInsertionPointToStart(mod.getBody()); 160 auto launchKernelFuncOp = 161 rewriter.create<mlir::LLVM::LLVMFuncOp>(loc, fctName, funcTy); 162 launchKernelFuncOp.setVisibility( 163 mlir::SymbolTable::Visibility::Private); 164 } 165 rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( 166 op, funcTy, cufLaunchKernel, 167 mlir::ValueRange{kernelPtr, adaptor.getGridSizeX(), 168 adaptor.getGridSizeY(), adaptor.getGridSizeZ(), 169 adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), 170 adaptor.getBlockSizeZ(), dynamicMemorySize, 171 kernelArgs, nullPtr}); 172 } 173 174 return mlir::success(); 175 } 176 }; 177 178 class CUFGPUToLLVMConversion 179 : public fir::impl::CUFGPUToLLVMConversionBase<CUFGPUToLLVMConversion> { 180 public: 181 void runOnOperation() override { 182 auto *ctx = &getContext(); 183 mlir::RewritePatternSet patterns(ctx); 184 mlir::ConversionTarget target(*ctx); 185 186 mlir::Operation *op = getOperation(); 187 mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op); 188 if (!module) 189 return signalPassFailure(); 190 191 std::optional<mlir::DataLayout> dl = 192 fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false); 193 fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false, 194 /*forceUnifiedTBAATree=*/false, *dl); 195 cuf::populateCUFGPUToLLVMConversionPatterns(typeConverter, patterns); 196 target.addIllegalOp<mlir::gpu::LaunchFuncOp>(); 197 target.addLegalDialect<mlir::LLVM::LLVMDialect>(); 198 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, 199 std::move(patterns)))) { 200 mlir::emitError(mlir::UnknownLoc::get(ctx), 201 "error in CUF GPU op conversion\n"); 202 signalPassFailure(); 203 } 204 } 205 }; 206 } // namespace 207 208 void cuf::populateCUFGPUToLLVMConversionPatterns( 209 const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns, 210 mlir::PatternBenefit benefit) { 211 patterns.add<GPULaunchKernelConversion>(converter, benefit); 212 } 213