xref: /llvm-project/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1ace01605SRiver Riddle //===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
2ace01605SRiver Riddle //
3ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ace01605SRiver Riddle //
7ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8ace01605SRiver Riddle //
// This file implements patterns to convert the ControlFlow dialect to the SPIR-V dialect.
10ace01605SRiver Riddle //
11ace01605SRiver Riddle //===----------------------------------------------------------------------===//
12ace01605SRiver Riddle 
13ace01605SRiver Riddle #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
14ace01605SRiver Riddle #include "../SPIRVCommon/Pattern.h"
15ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16ace01605SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17ace01605SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18ace01605SRiver Riddle #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19ace01605SRiver Riddle #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20ace01605SRiver Riddle #include "mlir/IR/AffineMap.h"
2113644f0bSJakub Kuderski #include "mlir/IR/PatternMatch.h"
2213644f0bSJakub Kuderski #include "mlir/Transforms/DialectConversion.h"
23ace01605SRiver Riddle #include "llvm/Support/Debug.h"
2413644f0bSJakub Kuderski #include "llvm/Support/FormatVariadic.h"
25ace01605SRiver Riddle 
26ace01605SRiver Riddle #define DEBUG_TYPE "cf-to-spirv-pattern"
27ace01605SRiver Riddle 
28ace01605SRiver Riddle using namespace mlir;
29ace01605SRiver Riddle 
30c8bc72dcSXiang Li /// Legailze target block arguments.
31c8bc72dcSXiang Li static LogicalResult legalizeBlockArguments(Block &block, Operation *op,
3213644f0bSJakub Kuderski                                             PatternRewriter &rewriter,
3313644f0bSJakub Kuderski                                             const TypeConverter &converter) {
34c8bc72dcSXiang Li   auto builder = OpBuilder::atBlockBegin(&block);
35c8bc72dcSXiang Li   for (unsigned i = 0; i < block.getNumArguments(); ++i) {
3626a0b277SMehdi Amini     BlockArgument arg = block.getArgument(i);
37c8bc72dcSXiang Li     if (converter.isLegal(arg.getType()))
38c8bc72dcSXiang Li       continue;
39c8bc72dcSXiang Li     Type ty = arg.getType();
40c8bc72dcSXiang Li     Type newTy = converter.convertType(ty);
41c8bc72dcSXiang Li     if (!newTy) {
4213644f0bSJakub Kuderski       return rewriter.notifyMatchFailure(
43c8bc72dcSXiang Li           op, llvm::formatv("failed to legalize type for argument {0})", arg));
4413644f0bSJakub Kuderski     }
45c8bc72dcSXiang Li     unsigned argNum = arg.getArgNumber();
46c8bc72dcSXiang Li     Location loc = arg.getLoc();
47c8bc72dcSXiang Li     Value newArg = block.insertArgument(argNum, newTy, loc);
48c8bc72dcSXiang Li     Value convertedValue = converter.materializeSourceConversion(
49c8bc72dcSXiang Li         builder, op->getLoc(), ty, newArg);
50c8bc72dcSXiang Li     if (!convertedValue) {
51c8bc72dcSXiang Li       return rewriter.notifyMatchFailure(
52c8bc72dcSXiang Li           op, llvm::formatv("failed to cast new argument {0} to type {1})",
53c8bc72dcSXiang Li                             newArg, ty));
54c8bc72dcSXiang Li     }
55c8bc72dcSXiang Li     arg.replaceAllUsesWith(convertedValue);
56c8bc72dcSXiang Li     block.eraseArgument(argNum + 1);
5713644f0bSJakub Kuderski   }
5813644f0bSJakub Kuderski   return success();
5913644f0bSJakub Kuderski }
6013644f0bSJakub Kuderski 
61ace01605SRiver Riddle //===----------------------------------------------------------------------===//
62ace01605SRiver Riddle // Operation conversion
63ace01605SRiver Riddle //===----------------------------------------------------------------------===//
64ace01605SRiver Riddle 
65ace01605SRiver Riddle namespace {
665ab6ef75SJakub Kuderski /// Converts cf.br to spirv.Branch.
6713644f0bSJakub Kuderski struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> {
6813644f0bSJakub Kuderski   using OpConversionPattern::OpConversionPattern;
69ace01605SRiver Riddle 
70ace01605SRiver Riddle   LogicalResult
71ace01605SRiver Riddle   matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor,
72ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
73c8bc72dcSXiang Li     if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter,
7413644f0bSJakub Kuderski                                       *getTypeConverter())))
7513644f0bSJakub Kuderski       return failure();
7613644f0bSJakub Kuderski 
77ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(),
78ace01605SRiver Riddle                                                  adaptor.getDestOperands());
79ace01605SRiver Riddle     return success();
80ace01605SRiver Riddle   }
81ace01605SRiver Riddle };
82ace01605SRiver Riddle 
835ab6ef75SJakub Kuderski /// Converts cf.cond_br to spirv.BranchConditional.
8413644f0bSJakub Kuderski struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> {
8513644f0bSJakub Kuderski   using OpConversionPattern::OpConversionPattern;
86ace01605SRiver Riddle 
87ace01605SRiver Riddle   LogicalResult
88ace01605SRiver Riddle   matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor,
89ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
90c8bc72dcSXiang Li     if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter,
9113644f0bSJakub Kuderski                                       *getTypeConverter())))
9213644f0bSJakub Kuderski       return failure();
9313644f0bSJakub Kuderski 
94c8bc72dcSXiang Li     if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter,
9513644f0bSJakub Kuderski                                       *getTypeConverter())))
9613644f0bSJakub Kuderski       return failure();
9713644f0bSJakub Kuderski 
98ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
9913644f0bSJakub Kuderski         op, adaptor.getCondition(), op.getTrueDest(),
10013644f0bSJakub Kuderski         adaptor.getTrueDestOperands(), op.getFalseDest(),
10113644f0bSJakub Kuderski         adaptor.getFalseDestOperands());
102ace01605SRiver Riddle     return success();
103ace01605SRiver Riddle   }
104ace01605SRiver Riddle };
105ace01605SRiver Riddle } // namespace
106ace01605SRiver Riddle 
107ace01605SRiver Riddle //===----------------------------------------------------------------------===//
108ace01605SRiver Riddle // Pattern population
109ace01605SRiver Riddle //===----------------------------------------------------------------------===//
110ace01605SRiver Riddle 
111ace01605SRiver Riddle void mlir::cf::populateControlFlowToSPIRVPatterns(
112*206fad0eSMatthias Springer     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
113ace01605SRiver Riddle   MLIRContext *context = patterns.getContext();
114ace01605SRiver Riddle 
115ace01605SRiver Riddle   patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context);
116ace01605SRiver Riddle }
117