// xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp (revision 69d757c0e8ffc5b49fda10df38e470a56d616ef4)
1 //===- KernelOutlining.cpp - Implementation of GPU kernel outlining -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the GPU dialect kernel outlining pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/GPU/GPUDialect.h"
14 #include "mlir/Dialect/GPU/Passes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/BlockAndValueMapping.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/SymbolTable.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/RegionUtils.h"
21 
22 using namespace mlir;
23 
24 template <typename OpTy>
25 static void createForAllDimensions(OpBuilder &builder, Location loc,
26                                    SmallVectorImpl<Value> &values) {
27   for (StringRef dim : {"x", "y", "z"}) {
28     Value v = builder.create<OpTy>(loc, builder.getIndexType(),
29                                    builder.getStringAttr(dim));
30     values.push_back(v);
31   }
32 }
33 
34 // Add operations generating block/thread ids and grid/block dimensions at the
35 // beginning of the `body` region and replace uses of the respective function
36 // arguments.
37 static void injectGpuIndexOperations(Location loc, Region &body) {
38   OpBuilder builder(loc->getContext());
39   Block &firstBlock = body.front();
40   builder.setInsertionPointToStart(&firstBlock);
41   SmallVector<Value, 12> indexOps;
42   createForAllDimensions<gpu::BlockIdOp>(builder, loc, indexOps);
43   createForAllDimensions<gpu::ThreadIdOp>(builder, loc, indexOps);
44   createForAllDimensions<gpu::GridDimOp>(builder, loc, indexOps);
45   createForAllDimensions<gpu::BlockDimOp>(builder, loc, indexOps);
46   // Replace the leading 12 function args with the respective thread/block index
47   // operations. Iterate backwards since args are erased and indices change.
48   for (int i = 11; i >= 0; --i) {
49     firstBlock.getArgument(i).replaceAllUsesWith(indexOps[i]);
50     firstBlock.eraseArgument(i);
51   }
52 }
53 
54 static bool isInliningBeneficiary(Operation *op) {
55   return isa<ConstantOp>(op) || isa<DimOp>(op);
56 }
57 
58 // Move arguments of the given kernel function into the function if this reduces
59 // the number of kernel arguments.
60 static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
61                                               gpu::LaunchFuncOp launch) {
62   OpBuilder kernelBuilder(kernelFunc.getBody());
63   auto &firstBlock = kernelFunc.getBody().front();
64   SmallVector<Value, 8> newLaunchArgs;
65   BlockAndValueMapping map;
66   for (int i = 0, e = launch.getNumKernelOperands(); i < e; ++i) {
67     map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i));
68   }
69   for (int i = launch.getNumKernelOperands() - 1; i >= 0; --i) {
70     auto operandOp = launch.getKernelOperand(i).getDefiningOp();
71     if (!operandOp || !isInliningBeneficiary(operandOp)) {
72       newLaunchArgs.push_back(launch.getKernelOperand(i));
73       continue;
74     }
75     // Only inline operations that do not create new arguments.
76     if (!llvm::all_of(operandOp->getOperands(),
77                       [map](Value value) { return map.contains(value); })) {
78       continue;
79     }
80     auto clone = kernelBuilder.clone(*operandOp, map);
81     firstBlock.getArgument(i).replaceAllUsesWith(clone->getResult(0));
82     firstBlock.eraseArgument(i);
83   }
84   if (newLaunchArgs.size() == launch.getNumKernelOperands())
85     return launch;
86 
87   std::reverse(newLaunchArgs.begin(), newLaunchArgs.end());
88   OpBuilder LaunchBuilder(launch);
89   SmallVector<Type, 8> newArgumentTypes;
90   newArgumentTypes.reserve(firstBlock.getNumArguments());
91   for (auto value : firstBlock.getArguments()) {
92     newArgumentTypes.push_back(value.getType());
93   }
94   kernelFunc.setType(LaunchBuilder.getFunctionType(newArgumentTypes, {}));
95   auto newLaunch = LaunchBuilder.create<gpu::LaunchFuncOp>(
96       launch.getLoc(), kernelFunc, launch.getGridSizeOperandValues(),
97       launch.getBlockSizeOperandValues(), newLaunchArgs);
98   launch.erase();
99   return newLaunch;
100 }
101 
102 // Outline the `gpu.launch` operation body into a kernel function. Replace
103 // `gpu.terminator` operations by `gpu.return` in the generated function.
104 static gpu::GPUFuncOp outlineKernelFunc(gpu::LaunchOp launchOp,
105                                         llvm::SetVector<Value> &operands) {
106   Location loc = launchOp.getLoc();
107   // Create a builder with no insertion point, insertion will happen separately
108   // due to symbol table manipulation.
109   OpBuilder builder(launchOp.getContext());
110 
111   // Identify uses from values defined outside of the scope of the launch
112   // operation.
113   getUsedValuesDefinedAbove(launchOp.body(), operands);
114 
115   SmallVector<Type, 4> kernelOperandTypes;
116   kernelOperandTypes.reserve(operands.size());
117   for (Value operand : operands) {
118     kernelOperandTypes.push_back(operand.getType());
119   }
120   FunctionType type =
121       FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
122   std::string kernelFuncName =
123       Twine(launchOp.getParentOfType<FuncOp>().getName(), "_kernel").str();
124   auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFuncName, type);
125   outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
126                        builder.getUnitAttr());
127   outlinedFunc.body().takeBody(launchOp.body());
128   injectGpuIndexOperations(loc, outlinedFunc.body());
129   Block &entryBlock = outlinedFunc.body().front();
130   for (Value operand : operands) {
131     BlockArgument newArg = entryBlock.addArgument(operand.getType());
132     replaceAllUsesInRegionWith(operand, newArg, outlinedFunc.body());
133   }
134   outlinedFunc.walk([](gpu::TerminatorOp op) {
135     OpBuilder replacer(op);
136     replacer.create<gpu::ReturnOp>(op.getLoc());
137     op.erase();
138   });
139 
140   return outlinedFunc;
141 }
142 
143 // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching
144 // `kernelFunc`. The kernel func contains the body of the `gpu.launch` with
145 // constant region arguments inlined.
146 static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
147                                   gpu::GPUFuncOp kernelFunc,
148                                   ValueRange operands) {
149   OpBuilder builder(launchOp);
150   auto launchFuncOp = builder.create<gpu::LaunchFuncOp>(
151       launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
152       launchOp.getBlockSizeOperandValues(), operands);
153   inlineBeneficiaryOps(kernelFunc, launchFuncOp);
154   launchOp.erase();
155 }
156 
157 namespace {
158 
159 /// Pass that moves the kernel of each LaunchOp into its separate nested module.
160 ///
161 /// This pass moves the kernel code of each LaunchOp into a function created
162 /// inside a nested module. It also creates an external function of the same
163 /// name in the parent module.
164 ///
165 /// The gpu.modules are intended to be compiled to a cubin blob independently in
166 /// a separate pass. The external functions can then be annotated with the
167 /// symbol of the cubin accessor function.
168 class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
169 public:
170   void runOnModule() override {
171     SymbolTable symbolTable(getModule());
172     bool modified = false;
173     for (auto func : getModule().getOps<FuncOp>()) {
174       // Insert just after the function.
175       Block::iterator insertPt(func.getOperation()->getNextNode());
176       func.walk([&](gpu::LaunchOp op) {
177         llvm::SetVector<Value> operands;
178         gpu::GPUFuncOp outlinedFunc = outlineKernelFunc(op, operands);
179 
180         // Create nested module and insert outlinedFunc. The module will
181         // originally get the same name as the function, but may be renamed on
182         // insertion into the parent module.
183         auto kernelModule = createKernelModule(outlinedFunc, symbolTable);
184         symbolTable.insert(kernelModule, insertPt);
185 
186         // Potentially changes signature, pulling in constants.
187         convertToLaunchFuncOp(op, outlinedFunc, operands.getArrayRef());
188         modified = true;
189       });
190     }
191 
192     // If any new module was inserted in this module, annotate this module as
193     // a container module.
194     if (modified)
195       getModule().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
196                           UnitAttr::get(&getContext()));
197   }
198 
199 private:
200   // Returns a gpu.module containing kernelFunc and all callees (recursive).
201   gpu::GPUModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc,
202                                       const SymbolTable &parentSymbolTable) {
203     // TODO: This code cannot use an OpBuilder because it must be inserted into
204     // a SymbolTable by the caller. SymbolTable needs to be refactored to
205     // prevent manual building of Ops with symbols in code using SymbolTables
206     // and then this needs to use the OpBuilder.
207     auto context = getModule().getContext();
208     Builder builder(context);
209     OperationState state(kernelFunc.getLoc(),
210                          gpu::GPUModuleOp::getOperationName());
211     gpu::GPUModuleOp::build(&builder, state, kernelFunc.getName());
212     auto kernelModule = cast<gpu::GPUModuleOp>(Operation::create(state));
213     SymbolTable symbolTable(kernelModule);
214     symbolTable.insert(kernelFunc);
215 
216     SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
217     while (!symbolDefWorklist.empty()) {
218       if (Optional<SymbolTable::UseRange> symbolUses =
219               SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
220         for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
221           StringRef symbolName =
222               symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
223           if (symbolTable.lookup(symbolName))
224             continue;
225 
226           Operation *symbolDefClone =
227               parentSymbolTable.lookup(symbolName)->clone();
228           symbolDefWorklist.push_back(symbolDefClone);
229           symbolTable.insert(symbolDefClone);
230         }
231       }
232     }
233 
234     return kernelModule;
235   }
236 };
237 
238 } // namespace
239 
240 std::unique_ptr<OpPassBase<ModuleOp>> mlir::createGpuKernelOutliningPass() {
241   return std::make_unique<GpuKernelOutliningPass>();
242 }
243 
244 static PassRegistration<GpuKernelOutliningPass>
245     pass("gpu-kernel-outlining",
246          "Outline gpu.launch bodies to kernel functions.");
247