xref: /llvm-project/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp (revision 48657bf29b01e95749b5ecd8c7f675c14a7948d1)
1 //===-- CUFGPUToLLVMConversion.cpp ----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Transforms/CUFGPUToLLVMConversion.h"
10 #include "flang/Common/Fortran.h"
11 #include "flang/Optimizer/CodeGen/TypeConverter.h"
12 #include "flang/Optimizer/Support/DataLayout.h"
13 #include "flang/Runtime/CUDA/common.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "llvm/Support/FormatVariadic.h"
20 
21 namespace fir {
22 #define GEN_PASS_DEF_CUFGPUTOLLVMCONVERSION
23 #include "flang/Optimizer/Transforms/Passes.h.inc"
24 } // namespace fir
25 
26 using namespace fir;
27 using namespace mlir;
28 using namespace Fortran::runtime;
29 
30 namespace {
31 
32 static mlir::Value createKernelArgArray(mlir::Location loc,
33                                         mlir::ValueRange operands,
34                                         mlir::PatternRewriter &rewriter) {
35 
36   auto *ctx = rewriter.getContext();
37   llvm::SmallVector<mlir::Type> structTypes(operands.size(), nullptr);
38 
39   for (auto [i, arg] : llvm::enumerate(operands))
40     structTypes[i] = arg.getType();
41 
42   auto structTy = mlir::LLVM::LLVMStructType::getLiteral(ctx, structTypes);
43   auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
44   mlir::Type i32Ty = rewriter.getI32Type();
45   auto zero = rewriter.create<mlir::LLVM::ConstantOp>(
46       loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0));
47   auto one = rewriter.create<mlir::LLVM::ConstantOp>(
48       loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 1));
49   mlir::Value argStruct =
50       rewriter.create<mlir::LLVM::AllocaOp>(loc, ptrTy, structTy, one);
51   auto size = rewriter.create<mlir::LLVM::ConstantOp>(
52       loc, i32Ty, rewriter.getIntegerAttr(i32Ty, structTypes.size()));
53   mlir::Value argArray =
54       rewriter.create<mlir::LLVM::AllocaOp>(loc, ptrTy, ptrTy, size);
55 
56   for (auto [i, arg] : llvm::enumerate(operands)) {
57     auto indice = rewriter.create<mlir::LLVM::ConstantOp>(
58         loc, i32Ty, rewriter.getIntegerAttr(i32Ty, i));
59     mlir::Value structMember = rewriter.create<LLVM::GEPOp>(
60         loc, ptrTy, structTy, argStruct,
61         mlir::ArrayRef<mlir::Value>({zero, indice}));
62     rewriter.create<LLVM::StoreOp>(loc, arg, structMember);
63     mlir::Value arrayMember = rewriter.create<LLVM::GEPOp>(
64         loc, ptrTy, ptrTy, argArray, mlir::ArrayRef<mlir::Value>({indice}));
65     rewriter.create<LLVM::StoreOp>(loc, structMember, arrayMember);
66   }
67   return argArray;
68 }
69 
70 struct GPULaunchKernelConversion
71     : public mlir::ConvertOpToLLVMPattern<mlir::gpu::LaunchFuncOp> {
72   explicit GPULaunchKernelConversion(
73       const fir::LLVMTypeConverter &typeConverter, mlir::PatternBenefit benefit)
74       : mlir::ConvertOpToLLVMPattern<mlir::gpu::LaunchFuncOp>(typeConverter,
75                                                               benefit) {}
76 
77   using OpAdaptor = typename mlir::gpu::LaunchFuncOp::Adaptor;
78 
79   mlir::LogicalResult
80   matchAndRewrite(mlir::gpu::LaunchFuncOp op, OpAdaptor adaptor,
81                   mlir::ConversionPatternRewriter &rewriter) const override {
82     mlir::Location loc = op.getLoc();
83     auto *ctx = rewriter.getContext();
84     mlir::ModuleOp mod = op->getParentOfType<mlir::ModuleOp>();
85     mlir::Value dynamicMemorySize = op.getDynamicSharedMemorySize();
86     mlir::Type i32Ty = rewriter.getI32Type();
87     if (!dynamicMemorySize)
88       dynamicMemorySize = rewriter.create<mlir::LLVM::ConstantOp>(
89           loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0));
90 
91     mlir::Value kernelArgs =
92         createKernelArgArray(loc, adaptor.getKernelOperands(), rewriter);
93 
94     auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
95     auto kernel = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(op.getKernelName());
96     mlir::Value kernelPtr;
97     if (!kernel) {
98       auto funcOp = mod.lookupSymbol<mlir::func::FuncOp>(op.getKernelName());
99       if (!funcOp)
100         return mlir::failure();
101       kernelPtr =
102           rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, funcOp.getName());
103     } else {
104       kernelPtr =
105           rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, kernel.getName());
106     }
107 
108     auto llvmIntPtrType = mlir::IntegerType::get(
109         ctx, this->getTypeConverter()->getPointerBitwidth(0));
110     auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);
111 
112     mlir::Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);
113 
114     if (op.hasClusterSize()) {
115       auto funcOp = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
116           RTNAME_STRING(CUFLaunchClusterKernel));
117       auto funcTy = mlir::LLVM::LLVMFunctionType::get(
118           voidTy,
119           {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
120            llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
121            llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
122           /*isVarArg=*/false);
123       auto cufLaunchClusterKernel = mlir::SymbolRefAttr::get(
124           mod.getContext(), RTNAME_STRING(CUFLaunchClusterKernel));
125       if (!funcOp) {
126         mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
127         rewriter.setInsertionPointToStart(mod.getBody());
128         auto launchKernelFuncOp = rewriter.create<mlir::LLVM::LLVMFuncOp>(
129             loc, RTNAME_STRING(CUFLaunchClusterKernel), funcTy);
130         launchKernelFuncOp.setVisibility(
131             mlir::SymbolTable::Visibility::Private);
132       }
133       rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
134           op, funcTy, cufLaunchClusterKernel,
135           mlir::ValueRange{kernelPtr, adaptor.getClusterSizeX(),
136                            adaptor.getClusterSizeY(), adaptor.getClusterSizeZ(),
137                            adaptor.getGridSizeX(), adaptor.getGridSizeY(),
138                            adaptor.getGridSizeZ(), adaptor.getBlockSizeX(),
139                            adaptor.getBlockSizeY(), adaptor.getBlockSizeZ(),
140                            dynamicMemorySize, kernelArgs, nullPtr});
141     } else {
142       auto procAttr =
143           op->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName());
144       bool isGridGlobal =
145           procAttr && procAttr.getValue() == cuf::ProcAttribute::GridGlobal;
146       llvm::StringRef fctName = isGridGlobal
147                                     ? RTNAME_STRING(CUFLaunchCooperativeKernel)
148                                     : RTNAME_STRING(CUFLaunchKernel);
149       auto funcOp = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(fctName);
150       auto funcTy = mlir::LLVM::LLVMFunctionType::get(
151           voidTy,
152           {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
153            llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
154           /*isVarArg=*/false);
155       auto cufLaunchKernel =
156           mlir::SymbolRefAttr::get(mod.getContext(), fctName);
157       if (!funcOp) {
158         mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
159         rewriter.setInsertionPointToStart(mod.getBody());
160         auto launchKernelFuncOp =
161             rewriter.create<mlir::LLVM::LLVMFuncOp>(loc, fctName, funcTy);
162         launchKernelFuncOp.setVisibility(
163             mlir::SymbolTable::Visibility::Private);
164       }
165       rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
166           op, funcTy, cufLaunchKernel,
167           mlir::ValueRange{kernelPtr, adaptor.getGridSizeX(),
168                            adaptor.getGridSizeY(), adaptor.getGridSizeZ(),
169                            adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
170                            adaptor.getBlockSizeZ(), dynamicMemorySize,
171                            kernelArgs, nullPtr});
172     }
173 
174     return mlir::success();
175   }
176 };
177 
178 class CUFGPUToLLVMConversion
179     : public fir::impl::CUFGPUToLLVMConversionBase<CUFGPUToLLVMConversion> {
180 public:
181   void runOnOperation() override {
182     auto *ctx = &getContext();
183     mlir::RewritePatternSet patterns(ctx);
184     mlir::ConversionTarget target(*ctx);
185 
186     mlir::Operation *op = getOperation();
187     mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
188     if (!module)
189       return signalPassFailure();
190 
191     std::optional<mlir::DataLayout> dl =
192         fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
193     fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
194                                          /*forceUnifiedTBAATree=*/false, *dl);
195     cuf::populateCUFGPUToLLVMConversionPatterns(typeConverter, patterns);
196     target.addIllegalOp<mlir::gpu::LaunchFuncOp>();
197     target.addLegalDialect<mlir::LLVM::LLVMDialect>();
198     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
199                                                   std::move(patterns)))) {
200       mlir::emitError(mlir::UnknownLoc::get(ctx),
201                       "error in CUF GPU op conversion\n");
202       signalPassFailure();
203     }
204   }
205 };
206 } // namespace
207 
208 void cuf::populateCUFGPUToLLVMConversionPatterns(
209     const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
210     mlir::PatternBenefit benefit) {
211   patterns.add<GPULaunchKernelConversion>(converter, benefit);
212 }
213