xref: /llvm-project/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp (revision ef6014e764825753293fabf65c92afe79ed402cd)
1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements passes to convert `gpu.launch_func` op into a sequence
10 // of LLVM calls that emulate the host and device sides.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
15 
16 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
17 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
18 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
19 #include "mlir/Conversion/LLVMCommon/Pattern.h"
20 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
21 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
22 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/SymbolTable.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/DialectConversion.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
34 
35 namespace mlir {
36 #define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
37 #include "mlir/Conversion/Passes.h.inc"
38 } // namespace mlir
39 
40 using namespace mlir;
41 
/// Prefix used to form the symbol name of the SPIR-V module that corresponds
/// to a GPU kernel module (`__spv__{kernel_module_name}`). It also serves as
/// the fallback module name when a SPIR-V module has no symbol name.
static constexpr char kSPIRVModule[] = "__spv__";
43 
44 //===----------------------------------------------------------------------===//
45 // Utility functions
46 //===----------------------------------------------------------------------===//
47 
48 /// Returns the string name of the `DescriptorSet` decoration.
descriptorSetName()49 static std::string descriptorSetName() {
50   return llvm::convertToSnakeFromCamelCase(
51       stringifyDecoration(spirv::Decoration::DescriptorSet));
52 }
53 
54 /// Returns the string name of the `Binding` decoration.
bindingName()55 static std::string bindingName() {
56   return llvm::convertToSnakeFromCamelCase(
57       stringifyDecoration(spirv::Decoration::Binding));
58 }
59 
60 /// Calculates the index of the kernel's operand that is represented by the
61 /// given global variable with the `bind` attribute. We assume that the index of
62 /// each kernel's operand is mapped to (descriptorSet, binding) by the map:
63 ///   i -> (0, i)
64 /// which is implemented under `LowerABIAttributesPass`.
calculateGlobalIndex(spirv::GlobalVariableOp op)65 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
66   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
67   return binding.getInt();
68 }
69 
70 /// Copies the given number of bytes from src to dst pointers.
copy(Location loc,Value dst,Value src,Value size,OpBuilder & builder)71 static void copy(Location loc, Value dst, Value src, Value size,
72                  OpBuilder &builder) {
73   builder.create<LLVM::MemcpyOp>(loc, dst, src, size, /*isVolatile=*/false);
74 }
75 
76 /// Encodes the binding and descriptor set numbers into a new symbolic name.
77 /// The name is specified by
78 ///   {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
79 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
80 /// binding numbers.
81 static std::string
createGlobalVariableWithBindName(spirv::GlobalVariableOp op,StringRef kernelModuleName)82 createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
83                                  StringRef kernelModuleName) {
84   IntegerAttr descriptorSet =
85       op->getAttrOfType<IntegerAttr>(descriptorSetName());
86   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
87   return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
88                        kernelModuleName.str(), op.getSymName().str(),
89                        std::to_string(descriptorSet.getInt()),
90                        std::to_string(binding.getInt()));
91 }
92 
93 /// Returns true if the given global variable has both a descriptor set number
94 /// and a binding number.
hasDescriptorSetAndBinding(spirv::GlobalVariableOp op)95 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
96   IntegerAttr descriptorSet =
97       op->getAttrOfType<IntegerAttr>(descriptorSetName());
98   IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
99   return descriptorSet && binding;
100 }
101 
102 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
103 /// arguments from the given SPIR-V module. We assume that the module contains a
104 /// single entry point function. Hence, all `spirv.GlobalVariable`s with a bind
105 /// attribute are kernel arguments.
getKernelGlobalVariables(spirv::ModuleOp module,DenseMap<uint32_t,spirv::GlobalVariableOp> & globalVariableMap)106 static LogicalResult getKernelGlobalVariables(
107     spirv::ModuleOp module,
108     DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
109   auto entryPoints = module.getOps<spirv::EntryPointOp>();
110   if (!llvm::hasSingleElement(entryPoints)) {
111     return module.emitError(
112         "The module must contain exactly one entry point function");
113   }
114   auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
115   for (auto globalOp : globalVariables) {
116     if (hasDescriptorSetAndBinding(globalOp))
117       globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
118   }
119   return success();
120 }
121 
122 /// Encodes the SPIR-V module's symbolic name into the name of the entry point
123 /// function.
encodeKernelName(spirv::ModuleOp module)124 static LogicalResult encodeKernelName(spirv::ModuleOp module) {
125   StringRef spvModuleName = module.getSymName().value_or(kSPIRVModule);
126   // We already know that the module contains exactly one entry point function
127   // based on `getKernelGlobalVariables()` call. Update this function's name
128   // to:
129   //   {spv_module_name}_{function_name}
130   auto entryPoints = module.getOps<spirv::EntryPointOp>();
131   if (!llvm::hasSingleElement(entryPoints)) {
132     return module.emitError(
133         "The module must contain exactly one entry point function");
134   }
135   spirv::EntryPointOp entryPoint = *entryPoints.begin();
136   StringRef funcName = entryPoint.getFn();
137   auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.getFnAttr());
138   StringAttr newFuncName =
139       StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
140   if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
141     return failure();
142   SymbolTable::setSymbolName(funcOp, newFuncName);
143   return success();
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // Conversion patterns
148 //===----------------------------------------------------------------------===//
149 
150 namespace {
151 
152 /// Structure to group information about the variables being copied.
153 struct CopyInfo {
154   Value dst;
155   Value src;
156   Value size;
157 };
158 
159 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we
160 /// copy the data to the global variable (emulating device side), call the
161 /// kernel as a normal void LLVM function, and copy the data back (emulating the
162 /// host side).
163 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
164   using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
165 
166   LogicalResult
matchAndRewrite(gpu::LaunchFuncOp launchOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const167   matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
168                   ConversionPatternRewriter &rewriter) const override {
169     auto *op = launchOp.getOperation();
170     MLIRContext *context = rewriter.getContext();
171     auto module = launchOp->getParentOfType<ModuleOp>();
172 
173     // Get the SPIR-V module that represents the gpu kernel module. The module
174     // is named:
175     //   __spv__{kernel_module_name}
176     // based on GPU to SPIR-V conversion.
177     StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
178     std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
179     auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
180         StringAttr::get(context, spvModuleName));
181     if (!spvModule) {
182       return launchOp.emitOpError("SPIR-V kernel module '")
183              << spvModuleName << "' is not found";
184     }
185 
186     // Declare kernel function in the main module so that it later can be linked
187     // with its definition from the kernel module. We know that the kernel
188     // function would have no arguments and the data is passed via global
189     // variables. The name of the kernel will be
190     //   {spv_module_name}_{kernel_function_name}
191     // to avoid symbolic name conflicts.
192     StringRef kernelFuncName = launchOp.getKernelName().getValue();
193     std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
194     auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
195         StringAttr::get(context, newKernelFuncName));
196     if (!kernelFunc) {
197       OpBuilder::InsertionGuard guard(rewriter);
198       rewriter.setInsertionPointToStart(module.getBody());
199       kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
200           rewriter.getUnknownLoc(), newKernelFuncName,
201           LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
202                                       ArrayRef<Type>()));
203       rewriter.setInsertionPoint(launchOp);
204     }
205 
206     // Get all global variables associated with the kernel operands.
207     DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
208     if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
209       return failure();
210 
211     // Traverse kernel operands that were converted to MemRefDescriptors. For
212     // each operand, create a global variable and copy data from operand to it.
213     Location loc = launchOp.getLoc();
214     SmallVector<CopyInfo, 4> copyInfo;
215     auto numKernelOperands = launchOp.getNumKernelOperands();
216     auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
217     for (const auto &operand : llvm::enumerate(kernelOperands)) {
218       // Check if the kernel's operand is a ranked memref.
219       auto memRefType = dyn_cast<MemRefType>(
220           launchOp.getKernelOperand(operand.index()).getType());
221       if (!memRefType)
222         return failure();
223 
224       // Calculate the size of the memref and get the pointer to the allocated
225       // buffer.
226       SmallVector<Value, 4> sizes;
227       SmallVector<Value, 4> strides;
228       Value sizeBytes;
229       getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
230                                sizeBytes);
231       MemRefDescriptor descriptor(operand.value());
232       Value src = descriptor.allocatedPtr(rewriter, loc);
233 
234       // Get the global variable in the SPIR-V module that is associated with
235       // the kernel operand. Construct its new name and create a corresponding
236       // LLVM dialect global variable.
237       spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
238       auto pointeeType =
239           cast<spirv::PointerType>(spirvGlobal.getType()).getPointeeType();
240       auto dstGlobalType = typeConverter->convertType(pointeeType);
241       if (!dstGlobalType)
242         return failure();
243       std::string name =
244           createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
245       // Check if this variable has already been created.
246       auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
247       if (!dstGlobal) {
248         OpBuilder::InsertionGuard guard(rewriter);
249         rewriter.setInsertionPointToStart(module.getBody());
250         dstGlobal = rewriter.create<LLVM::GlobalOp>(
251             loc, dstGlobalType,
252             /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
253             /*alignment=*/0);
254         rewriter.setInsertionPoint(launchOp);
255       }
256 
257       // Copy the data from src operand pointer to dst global variable. Save
258       // src, dst and size so that we can copy data back after emulating the
259       // kernel call.
260       Value dst = rewriter.create<LLVM::AddressOfOp>(
261           loc, typeConverter->convertType(spirvGlobal.getType()),
262           dstGlobal.getSymName());
263       copy(loc, dst, src, sizeBytes, rewriter);
264 
265       CopyInfo info;
266       info.dst = dst;
267       info.src = src;
268       info.size = sizeBytes;
269       copyInfo.push_back(info);
270     }
271     // Create a call to the kernel and copy the data back.
272     rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
273                                               ArrayRef<Value>());
274     for (CopyInfo info : copyInfo)
275       copy(loc, info.src, info.dst, info.size, rewriter);
276     return success();
277   }
278 };
279 
280 class LowerHostCodeToLLVM
281     : public impl::LowerHostCodeToLLVMPassBase<LowerHostCodeToLLVM> {
282 public:
283   using Base::Base;
284 
runOnOperation()285   void runOnOperation() override {
286     ModuleOp module = getOperation();
287 
288     // Erase the GPU module.
289     for (auto gpuModule :
290          llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
291       gpuModule.erase();
292 
293     // Request C wrapper emission.
294     for (auto func : module.getOps<func::FuncOp>()) {
295       func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
296                     UnitAttr::get(&getContext()));
297     }
298 
299     // Specify options to lower to LLVM and pull in the conversion patterns.
300     LowerToLLVMOptions options(module.getContext());
301 
302     auto *context = module.getContext();
303     RewritePatternSet patterns(context);
304     LLVMTypeConverter typeConverter(context, options);
305     mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
306     populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
307     populateFuncToLLVMConversionPatterns(typeConverter, patterns);
308     patterns.add<GPULaunchLowering>(typeConverter);
309 
310     // Pull in SPIR-V type conversion patterns to convert SPIR-V global
311     // variable's type to LLVM dialect type.
312     populateSPIRVToLLVMTypeConversion(typeConverter);
313 
314     ConversionTarget target(*context);
315     target.addLegalDialect<LLVM::LLVMDialect>();
316     if (failed(applyPartialConversion(module, target, std::move(patterns))))
317       signalPassFailure();
318 
319     // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
320     // conflicts.
321     for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
322       if (failed(encodeKernelName(spvModule))) {
323         signalPassFailure();
324         return;
325       }
326     }
327   }
328 };
329 } // namespace
330