xref: /llvm-project/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns to convert ControlFlow dialect ops to SPIR-V.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
14 #include "../SPIRVCommon/Pattern.h"
15 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
25 
26 #define DEBUG_TYPE "cf-to-spirv-pattern"
27 
28 using namespace mlir;
29 
30 /// Legailze target block arguments.
31 static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
32                                             PatternRewriter &rewriter,
33                                             const TypeConverter &converter) {
34   auto builder = OpBuilder::atBlockBegin(&block);
35   for (unsigned i = 0; i < block.getNumArguments(); ++i) {
36     BlockArgument arg = block.getArgument(i);
37     if (converter.isLegal(arg.getType()))
38       continue;
39     Type ty = arg.getType();
40     Type newTy = converter.convertType(ty);
41     if (!newTy) {
42       return rewriter.notifyMatchFailure(
43           op, llvm::formatv("failed to legalize type for argument {0})", arg));
44     }
45     unsigned argNum = arg.getArgNumber();
46     Location loc = arg.getLoc();
47     Value newArg = block.insertArgument(argNum, newTy, loc);
48     Value convertedValue = converter.materializeSourceConversion(
49         builder, op->getLoc(), ty, newArg);
50     if (!convertedValue) {
51       return rewriter.notifyMatchFailure(
52           op, llvm::formatv("failed to cast new argument {0} to type {1})",
53                             newArg, ty));
54     }
55     arg.replaceAllUsesWith(convertedValue);
56     block.eraseArgument(argNum + 1);
57   }
58   return success();
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // Operation conversion
63 //===----------------------------------------------------------------------===//
64 
65 namespace {
66 /// Converts cf.br to spirv.Branch.
67 struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
68   using OpConversionPattern::OpConversionPattern;
69 
70   LogicalResult
71   matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
72                   ConversionPatternRewriter &rewriter) const override {
73     if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
74                                       *getTypeConverter())))
75       return failure();
76 
77     rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
78                                                  adaptor.getDestOperands());
79     return success();
80   }
81 };
82 
83 /// Converts cf.cond_br to spirv.BranchConditional.
84 struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
85   using OpConversionPattern::OpConversionPattern;
86 
87   LogicalResult
88   matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
89                   ConversionPatternRewriter &rewriter) const override {
90     if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
91                                       *getTypeConverter())))
92       return failure();
93 
94     if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter,
95                                       *getTypeConverter())))
96       return failure();
97 
98     rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
99         op, adaptor.getCondition(), op.getTrueDest(),
100         adaptor.getTrueDestOperands(), op.getFalseDest(),
101         adaptor.getFalseDestOperands());
102     return success();
103   }
104 };
105 } // namespace
106 
107 //===----------------------------------------------------------------------===//
108 // Pattern population
109 //===----------------------------------------------------------------------===//
110 
111 void mlir::cf::populateControlFlowToSPIRVPatterns(
112     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
113   MLIRContext *context = patterns.getContext();
114 
115   patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
116 }
117