xref: /llvm-project/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp (revision 4562e389a43caa2e30ebf277c12743edafe6a0ac)
1 //===- KernelOutlining.cpp - Implementation of GPU kernel outlining -------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements the GPU dialect kernel outlining pass.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/Dialect/GPU/GPUDialect.h"
23 #include "mlir/Dialect/GPU/Passes.h"
24 #include "mlir/Dialect/StandardOps/Ops.h"
25 #include "mlir/IR/BlockAndValueMapping.h"
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/SymbolTable.h"
28 #include "mlir/Pass/Pass.h"
29 
30 using namespace mlir;
31 
32 template <typename OpTy>
33 static void createForAllDimensions(OpBuilder &builder, Location loc,
34                                    SmallVectorImpl<Value *> &values) {
35   for (StringRef dim : {"x", "y", "z"}) {
36     Value *v = builder.create<OpTy>(loc, builder.getIndexType(),
37                                     builder.getStringAttr(dim));
38     values.push_back(v);
39   }
40 }
41 
42 // Add operations generating block/thread ids and grid/block dimensions at the
43 // beginning of the `body` region and replace uses of the respective function
44 // arguments.
45 static void injectGpuIndexOperations(Location loc, Region &body) {
46   OpBuilder builder(loc->getContext());
47   Block &firstBlock = body.front();
48   builder.setInsertionPointToStart(&firstBlock);
49   SmallVector<Value *, 12> indexOps;
50   createForAllDimensions<gpu::BlockIdOp>(builder, loc, indexOps);
51   createForAllDimensions<gpu::ThreadIdOp>(builder, loc, indexOps);
52   createForAllDimensions<gpu::GridDimOp>(builder, loc, indexOps);
53   createForAllDimensions<gpu::BlockDimOp>(builder, loc, indexOps);
54   // Replace the leading 12 function args with the respective thread/block index
55   // operations. Iterate backwards since args are erased and indices change.
56   for (int i = 11; i >= 0; --i) {
57     firstBlock.getArgument(i)->replaceAllUsesWith(indexOps[i]);
58     firstBlock.eraseArgument(i);
59   }
60 }
61 
62 static bool isInliningBeneficiary(Operation *op) {
63   return isa<ConstantOp>(op) || isa<DimOp>(op);
64 }
65 
66 // Move arguments of the given kernel function into the function if this reduces
67 // the number of kernel arguments.
68 static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
69                                               gpu::LaunchFuncOp launch) {
70   OpBuilder kernelBuilder(kernelFunc.getBody());
71   auto &firstBlock = kernelFunc.getBody().front();
72   SmallVector<Value *, 8> newLaunchArgs;
73   BlockAndValueMapping map;
74   for (int i = 0, e = launch.getNumKernelOperands(); i < e; ++i) {
75     map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i));
76   }
77   for (int i = launch.getNumKernelOperands() - 1; i >= 0; --i) {
78     auto operandOp = launch.getKernelOperand(i)->getDefiningOp();
79     if (!operandOp || !isInliningBeneficiary(operandOp)) {
80       newLaunchArgs.push_back(launch.getKernelOperand(i));
81       continue;
82     }
83     // Only inline operations that do not create new arguments.
84     if (!llvm::all_of(operandOp->getOperands(),
85                       [map](Value *value) { return map.contains(value); })) {
86       continue;
87     }
88     auto clone = kernelBuilder.clone(*operandOp, map);
89     firstBlock.getArgument(i)->replaceAllUsesWith(clone->getResult(0));
90     firstBlock.eraseArgument(i);
91   }
92   if (newLaunchArgs.size() == launch.getNumKernelOperands())
93     return launch;
94 
95   std::reverse(newLaunchArgs.begin(), newLaunchArgs.end());
96   OpBuilder LaunchBuilder(launch);
97   SmallVector<Type, 8> newArgumentTypes;
98   newArgumentTypes.reserve(firstBlock.getNumArguments());
99   for (auto value : firstBlock.getArguments()) {
100     newArgumentTypes.push_back(value->getType());
101   }
102   kernelFunc.setType(LaunchBuilder.getFunctionType(newArgumentTypes, {}));
103   auto newLaunch = LaunchBuilder.create<gpu::LaunchFuncOp>(
104       launch.getLoc(), kernelFunc, launch.getGridSizeOperandValues(),
105       launch.getBlockSizeOperandValues(), newLaunchArgs);
106   launch.erase();
107   return newLaunch;
108 }
109 
110 // Outline the `gpu.launch` operation body into a kernel function. Replace
111 // `gpu.return` operations by `std.return` in the generated function.
112 static gpu::GPUFuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
113   Location loc = launchOp.getLoc();
114   // Create a builder with no insertion point, insertion will happen separately
115   // due to symbol table manipulation.
116   OpBuilder builder(launchOp.getContext());
117 
118   SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
119   FunctionType type =
120       FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
121   std::string kernelFuncName =
122       Twine(launchOp.getParentOfType<FuncOp>().getName(), "_kernel").str();
123   auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFuncName, type);
124   outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
125                        builder.getUnitAttr());
126   outlinedFunc.body().takeBody(launchOp.body());
127   injectGpuIndexOperations(loc, outlinedFunc.body());
128   return outlinedFunc;
129 }
130 
131 // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching
132 // `kernelFunc`. The kernel func contains the body of the `gpu.launch` with
133 // constant region arguments inlined.
134 static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
135                                   gpu::GPUFuncOp kernelFunc) {
136   OpBuilder builder(launchOp);
137   auto launchFuncOp = builder.create<gpu::LaunchFuncOp>(
138       launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
139       launchOp.getBlockSizeOperandValues(), launchOp.getKernelOperandValues());
140   inlineBeneficiaryOps(kernelFunc, launchFuncOp);
141   launchOp.erase();
142 }
143 
144 namespace {
145 
146 /// Pass that moves the kernel of each LaunchOp into its separate nested module.
147 ///
148 /// This pass moves the kernel code of each LaunchOp into a function created
149 /// inside a nested module. It also creates an external function of the same
150 /// name in the parent module.
151 ///
152 /// The kernel modules are intended to be compiled to a cubin blob independently
153 /// in a separate pass. The external functions can then be annotated with the
154 /// symbol of the cubin accessor function.
155 class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
156 public:
157   void runOnModule() override {
158     SymbolTable symbolTable(getModule());
159     bool modified = false;
160     for (auto func : getModule().getOps<FuncOp>()) {
161       // Insert just after the function.
162       Block::iterator insertPt(func.getOperation()->getNextNode());
163       func.walk([&](gpu::LaunchOp op) {
164         gpu::GPUFuncOp outlinedFunc = outlineKernelFunc(op);
165 
166         // Create nested module and insert outlinedFunc. The module will
167         // originally get the same name as the function, but may be renamed on
168         // insertion into the parent module.
169         auto kernelModule = createKernelModule(outlinedFunc, symbolTable);
170         symbolTable.insert(kernelModule, insertPt);
171 
172         // Potentially changes signature, pulling in constants.
173         convertToLaunchFuncOp(op, outlinedFunc);
174         modified = true;
175       });
176     }
177 
178     // If any new module was inserted in this module, annotate this module as
179     // a container module.
180     if (modified)
181       getModule().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
182                           UnitAttr::get(&getContext()));
183   }
184 
185 private:
186   // Returns a module containing kernelFunc and all callees (recursive).
187   ModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc,
188                               const SymbolTable &parentSymbolTable) {
189     auto context = getModule().getContext();
190     Builder builder(context);
191     auto kernelModule =
192         ModuleOp::create(builder.getUnknownLoc(), kernelFunc.getName());
193     kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
194                          builder.getUnitAttr());
195     SymbolTable symbolTable(kernelModule);
196     symbolTable.insert(kernelFunc);
197 
198     SmallVector<Operation *, 8> symbolDefWorklist = {kernelFunc};
199     while (!symbolDefWorklist.empty()) {
200       if (Optional<SymbolTable::UseRange> symbolUses =
201               SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
202         for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
203           StringRef symbolName =
204               symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
205           if (symbolTable.lookup(symbolName))
206             continue;
207 
208           Operation *symbolDefClone =
209               parentSymbolTable.lookup(symbolName)->clone();
210           symbolDefWorklist.push_back(symbolDefClone);
211           symbolTable.insert(symbolDefClone);
212         }
213       }
214     }
215 
216     return kernelModule;
217   }
218 };
219 
220 } // namespace
221 
222 std::unique_ptr<OpPassBase<ModuleOp>> mlir::createGpuKernelOutliningPass() {
223   return std::make_unique<GpuKernelOutliningPass>();
224 }
225 
226 static PassRegistration<GpuKernelOutliningPass>
227     pass("gpu-kernel-outlining",
228          "Outline gpu.launch bodies to kernel functions.");
229