xref: /llvm-project/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SCF dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/Support/FormatVariadic.h"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Context
26 //===----------------------------------------------------------------------===//
27 
28 namespace mlir {
29 struct ScfToSPIRVContextImpl {
30   // Map between the spirv region control flow operation (spirv.mlir.loop or
31   // spirv.mlir.selection) to the VariableOp created to store the region
32   // results. The order of the VariableOp matches the order of the results.
33   DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
34 };
35 } // namespace mlir
36 
37 /// We use ScfToSPIRVContext to store information about the lowering of the scf
38 /// region that need to be used later on. When we lower scf.for/scf.if we create
39 /// VariableOp to store the results. We need to keep track of the VariableOp
40 /// created as we need to insert stores into them when lowering Yield. Those
41 /// StoreOp cannot be created earlier as they may use a different type than
42 /// yield operands.
43 ScfToSPIRVContext::ScfToSPIRVContext() {
44   impl = std::make_unique<::ScfToSPIRVContextImpl>();
45 }
46 
47 ScfToSPIRVContext::~ScfToSPIRVContext() = default;
48 
49 namespace {
50 
51 //===----------------------------------------------------------------------===//
52 // Helper Functions
53 //===----------------------------------------------------------------------===//
54 
55 /// Replaces SCF op outputs with SPIR-V variable loads.
56 /// We create VariableOp to handle the results value of the control flow region.
57 /// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right
58 /// after the loop we load the value from the allocation and use it as the SCF
59 /// op result.
60 template <typename ScfOp, typename OpTy>
61 void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
62                            ConversionPatternRewriter &rewriter,
63                            ScfToSPIRVContextImpl *scfToSPIRVContext,
64                            ArrayRef<Type> returnTypes) {
65 
66   Location loc = scfOp.getLoc();
67   auto &allocas = scfToSPIRVContext->outputVars[newOp];
68   // Clearing the allocas is necessary in case a dialect conversion path failed
69   // previously, and this is the second attempt of this conversion.
70   allocas.clear();
71   SmallVector<Value, 8> resultValue;
72   for (Type convertedType : returnTypes) {
73     auto pointerType =
74         spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
75     rewriter.setInsertionPoint(newOp);
76     auto alloc = rewriter.create<spirv::VariableOp>(
77         loc, pointerType, spirv::StorageClass::Function,
78         /*initializer=*/nullptr);
79     allocas.push_back(alloc);
80     rewriter.setInsertionPointAfter(newOp);
81     Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
82     resultValue.push_back(loadResult);
83   }
84   rewriter.replaceOp(scfOp, resultValue);
85 }
86 
87 Region::iterator getBlockIt(Region &region, unsigned index) {
88   return std::next(region.begin(), index);
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // Conversion Patterns
93 //===----------------------------------------------------------------------===//
94 
95 /// Common class for all vector to GPU patterns.
96 template <typename OpTy>
97 class SCFToSPIRVPattern : public OpConversionPattern<OpTy> {
98 public:
99   SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter,
100                     ScfToSPIRVContextImpl *scfToSPIRVContext)
101       : OpConversionPattern<OpTy>::OpConversionPattern(converter, context),
102         scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {}
103 
104 protected:
105   ScfToSPIRVContextImpl *scfToSPIRVContext;
106   // FIXME: We explicitly keep a reference of the type converter here instead of
107   // passing it to OpConversionPattern during construction. This effectively
108   // bypasses the conversion framework's automation on type conversion. This is
109   // needed right now because the conversion framework will unconditionally
110   // legalize all types used by SCF ops upon discovering them, for example, the
111   // types of loop carried values. We use SPIR-V variables for those loop
112   // carried values. Depending on the available capabilities, the SPIR-V
113   // variable can be different, for example, cooperative matrix or normal
114   // variable. We'd like to detach the conversion of the loop carried values
115   // from the SCF ops (which is mainly a region). So we need to "mark" types
116   // used by SCF ops as legal, if to use the conversion framework for type
117   // conversion. There isn't a straightforward way to do that yet, as when
118   // converting types, ops aren't taken into consideration. Therefore, we just
119   // bypass the framework's type conversion for now.
120   const SPIRVTypeConverter &typeConverter;
121 };
122 
123 //===----------------------------------------------------------------------===//
124 // scf::ForOp
125 //===----------------------------------------------------------------------===//
126 
127 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
128 struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
129   using SCFToSPIRVPattern::SCFToSPIRVPattern;
130 
131   LogicalResult
132   matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor,
133                   ConversionPatternRewriter &rewriter) const override {
134     // scf::ForOp can be lowered to the structured control flow represented by
135     // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
136     // latch and the merge block the exit block. The resulting spirv::LoopOp has
137     // a single back edge from the continue to header block, and a single exit
138     // from header to merge.
139     auto loc = forOp.getLoc();
140     auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
141     loopOp.addEntryAndMergeBlock(rewriter);
142 
143     OpBuilder::InsertionGuard guard(rewriter);
144     // Create the block for the header.
145     Block *header = rewriter.createBlock(&loopOp.getBody(),
146                                          getBlockIt(loopOp.getBody(), 1));
147     rewriter.setInsertionPointAfter(loopOp);
148 
149     // Create the new induction variable to use.
150     Value adapLowerBound = adaptor.getLowerBound();
151     BlockArgument newIndVar =
152         header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc());
153     for (Value arg : adaptor.getInitArgs())
154       header->addArgument(arg.getType(), arg.getLoc());
155     Block *body = forOp.getBody();
156 
157     // Apply signature conversion to the body of the forOp. It has a single
158     // block, with argument which is the induction variable. That has to be
159     // replaced with the new induction variable.
160     TypeConverter::SignatureConversion signatureConverter(
161         body->getNumArguments());
162     signatureConverter.remapInput(0, newIndVar);
163     for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
164       signatureConverter.remapInput(i, header->getArgument(i));
165     body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
166                                              signatureConverter);
167 
168     // Move the blocks from the forOp into the loopOp. This is the body of the
169     // loopOp.
170     rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(),
171                                 getBlockIt(loopOp.getBody(), 2));
172 
173     SmallVector<Value, 8> args(1, adaptor.getLowerBound());
174     args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
175     // Branch into it from the entry.
176     rewriter.setInsertionPointToEnd(&(loopOp.getBody().front()));
177     rewriter.create<spirv::BranchOp>(loc, header, args);
178 
179     // Generate the rest of the loop header.
180     rewriter.setInsertionPointToEnd(header);
181     auto *mergeBlock = loopOp.getMergeBlock();
182     auto cmpOp = rewriter.create<spirv::SLessThanOp>(
183         loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound());
184 
185     rewriter.create<spirv::BranchConditionalOp>(
186         loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
187 
188     // Generate instructions to increment the step of the induction variable and
189     // branch to the header.
190     Block *continueBlock = loopOp.getContinueBlock();
191     rewriter.setInsertionPointToEnd(continueBlock);
192 
193     // Add the step to the induction variable and branch to the header.
194     Value updatedIndVar = rewriter.create<spirv::IAddOp>(
195         loc, newIndVar.getType(), newIndVar, adaptor.getStep());
196     rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
197 
198     // Infer the return types from the init operands. Vector type may get
199     // converted to CooperativeMatrix or to Vector type, to avoid having complex
200     // extra logic to figure out the right type we just infer it from the Init
201     // operands.
202     SmallVector<Type, 8> initTypes;
203     for (auto arg : adaptor.getInitArgs())
204       initTypes.push_back(arg.getType());
205     replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext,
206                           initTypes);
207     return success();
208   }
209 };
210 
211 //===----------------------------------------------------------------------===//
212 // scf::IfOp
213 //===----------------------------------------------------------------------===//
214 
215 /// Pattern to convert a scf::IfOp within kernel functions into
216 /// spirv::SelectionOp.
217 struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
218   using SCFToSPIRVPattern::SCFToSPIRVPattern;
219 
220   LogicalResult
221   matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor,
222                   ConversionPatternRewriter &rewriter) const override {
223     // When lowering `scf::IfOp` we explicitly create a selection header block
224     // before the control flow diverges and a merge block where control flow
225     // subsequently converges.
226     auto loc = ifOp.getLoc();
227 
228     // Create `spirv.selection` operation, selection header block and merge
229     // block.
230     auto selectionOp =
231         rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
232     auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
233                                             selectionOp.getBody().end());
234     rewriter.create<spirv::MergeOp>(loc);
235 
236     OpBuilder::InsertionGuard guard(rewriter);
237     auto *selectionHeaderBlock =
238         rewriter.createBlock(&selectionOp.getBody().front());
239 
240     // Inline `then` region before the merge block and branch to it.
241     auto &thenRegion = ifOp.getThenRegion();
242     auto *thenBlock = &thenRegion.front();
243     rewriter.setInsertionPointToEnd(&thenRegion.back());
244     rewriter.create<spirv::BranchOp>(loc, mergeBlock);
245     rewriter.inlineRegionBefore(thenRegion, mergeBlock);
246 
247     auto *elseBlock = mergeBlock;
248     // If `else` region is not empty, inline that region before the merge block
249     // and branch to it.
250     if (!ifOp.getElseRegion().empty()) {
251       auto &elseRegion = ifOp.getElseRegion();
252       elseBlock = &elseRegion.front();
253       rewriter.setInsertionPointToEnd(&elseRegion.back());
254       rewriter.create<spirv::BranchOp>(loc, mergeBlock);
255       rewriter.inlineRegionBefore(elseRegion, mergeBlock);
256     }
257 
258     // Create a `spirv.BranchConditional` operation for selection header block.
259     rewriter.setInsertionPointToEnd(selectionHeaderBlock);
260     rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(),
261                                                 thenBlock, ArrayRef<Value>(),
262                                                 elseBlock, ArrayRef<Value>());
263 
264     SmallVector<Type, 8> returnTypes;
265     for (auto result : ifOp.getResults()) {
266       auto convertedType = typeConverter.convertType(result.getType());
267       if (!convertedType)
268         return rewriter.notifyMatchFailure(
269             loc,
270             llvm::formatv("failed to convert type '{0}'", result.getType()));
271 
272       returnTypes.push_back(convertedType);
273     }
274     replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext,
275                           returnTypes);
276     return success();
277   }
278 };
279 
280 //===----------------------------------------------------------------------===//
281 // scf::YieldOp
282 //===----------------------------------------------------------------------===//
283 
284 struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> {
285 public:
286   using SCFToSPIRVPattern::SCFToSPIRVPattern;
287 
288   LogicalResult
289   matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
290                   ConversionPatternRewriter &rewriter) const override {
291     ValueRange operands = adaptor.getOperands();
292 
293     Operation *parent = terminatorOp->getParentOp();
294 
295     // TODO: Implement conversion for the remaining `scf` ops.
296     if (parent->getDialect()->getNamespace() ==
297             scf::SCFDialect::getDialectNamespace() &&
298         !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent))
299       return rewriter.notifyMatchFailure(
300           terminatorOp,
301           llvm::formatv("conversion not supported for parent op: '{0}'",
302                         parent->getName()));
303 
304     // If the region return values, store each value into the associated
305     // VariableOp created during lowering of the parent region.
306     if (!operands.empty()) {
307       auto &allocas = scfToSPIRVContext->outputVars[parent];
308       if (allocas.size() != operands.size())
309         return failure();
310 
311       auto loc = terminatorOp.getLoc();
312       for (unsigned i = 0, e = operands.size(); i < e; i++)
313         rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
314       if (isa<spirv::LoopOp>(parent)) {
315         // For loops we also need to update the branch jumping back to the
316         // header.
317         auto br = cast<spirv::BranchOp>(
318             rewriter.getInsertionBlock()->getTerminator());
319         SmallVector<Value, 8> args(br.getBlockArguments());
320         args.append(operands.begin(), operands.end());
321         rewriter.setInsertionPoint(br);
322         rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
323                                          args);
324         rewriter.eraseOp(br);
325       }
326     }
327     rewriter.eraseOp(terminatorOp);
328     return success();
329   }
330 };
331 
332 //===----------------------------------------------------------------------===//
333 // scf::WhileOp
334 //===----------------------------------------------------------------------===//
335 
336 struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> {
337   using SCFToSPIRVPattern::SCFToSPIRVPattern;
338 
339   LogicalResult
340   matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
341                   ConversionPatternRewriter &rewriter) const override {
342     auto loc = whileOp.getLoc();
343     auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
344     loopOp.addEntryAndMergeBlock(rewriter);
345 
346     Region &beforeRegion = whileOp.getBefore();
347     Region &afterRegion = whileOp.getAfter();
348 
349     if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) ||
350         failed(rewriter.convertRegionTypes(&afterRegion, typeConverter)))
351       return rewriter.notifyMatchFailure(whileOp,
352                                          "Failed to convert region types");
353 
354     OpBuilder::InsertionGuard guard(rewriter);
355 
356     Block &entryBlock = *loopOp.getEntryBlock();
357     Block &beforeBlock = beforeRegion.front();
358     Block &afterBlock = afterRegion.front();
359     Block &mergeBlock = *loopOp.getMergeBlock();
360 
361     auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
362     SmallVector<Value> condArgs;
363     if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs)))
364       return failure();
365 
366     Value conditionVal = rewriter.getRemappedValue(cond.getCondition());
367     if (!conditionVal)
368       return failure();
369 
370     auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
371     SmallVector<Value> yieldArgs;
372     if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs)))
373       return failure();
374 
375     // Move the while before block as the initial loop header block.
376     rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(),
377                                 getBlockIt(loopOp.getBody(), 1));
378 
379     // Move the while after block as the initial loop body block.
380     rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(),
381                                 getBlockIt(loopOp.getBody(), 2));
382 
383     // Jump from the loop entry block to the loop header block.
384     rewriter.setInsertionPointToEnd(&entryBlock);
385     rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits());
386 
387     auto condLoc = cond.getLoc();
388 
389     SmallVector<Value> resultValues(condArgs.size());
390 
391     // For other SCF ops, the scf.yield op yields the value for the whole SCF
392     // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V
393     // local variables. But for the scf.while op, the scf.yield op yields a
394     // value for the before region, which may not matching the whole op's
395     // result. Instead, the scf.condition op returns values matching the whole
396     // op's results. So we need to create/load/store variables according to
397     // that.
398     for (const auto &it : llvm::enumerate(condArgs)) {
399       auto res = it.value();
400       auto i = it.index();
401       auto pointerType =
402           spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
403 
404       // Create local variables before the scf.while op.
405       rewriter.setInsertionPoint(loopOp);
406       auto alloc = rewriter.create<spirv::VariableOp>(
407           condLoc, pointerType, spirv::StorageClass::Function,
408           /*initializer=*/nullptr);
409 
410       // Load the final result values after the scf.while op.
411       rewriter.setInsertionPointAfter(loopOp);
412       auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
413       resultValues[i] = loadResult;
414 
415       // Store the current iteration's result value.
416       rewriter.setInsertionPointToEnd(&beforeBlock);
417       rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
418     }
419 
420     rewriter.setInsertionPointToEnd(&beforeBlock);
421     rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
422         cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt);
423 
424     // Convert the scf.yield op to a branch back to the header block.
425     rewriter.setInsertionPointToEnd(&afterBlock);
426     rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
427                                                  yieldArgs);
428 
429     rewriter.replaceOp(whileOp, resultValues);
430     return success();
431   }
432 };
433 } // namespace
434 
435 //===----------------------------------------------------------------------===//
436 // Public API
437 //===----------------------------------------------------------------------===//
438 
439 void mlir::populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
440                                       ScfToSPIRVContext &scfToSPIRVContext,
441                                       RewritePatternSet &patterns) {
442   patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
443                WhileOpConversion>(patterns.getContext(), typeConverter,
444                                   scfToSPIRVContext.getImpl());
445 }
446