//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the ControlFlow dialect to the
// SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "cf-to-spirv-pattern"

using namespace mlir;

/// Legalize target block arguments.
31c8bc72dcSXiang Li static LogicalResult legalizeBlockArguments(Block &block, Operation *op, 3213644f0bSJakub Kuderski PatternRewriter &rewriter, 3313644f0bSJakub Kuderski const TypeConverter &converter) { 34c8bc72dcSXiang Li auto builder = OpBuilder::atBlockBegin(&block); 35c8bc72dcSXiang Li for (unsigned i = 0; i < block.getNumArguments(); ++i) { 3626a0b277SMehdi Amini BlockArgument arg = block.getArgument(i); 37c8bc72dcSXiang Li if (converter.isLegal(arg.getType())) 38c8bc72dcSXiang Li continue; 39c8bc72dcSXiang Li Type ty = arg.getType(); 40c8bc72dcSXiang Li Type newTy = converter.convertType(ty); 41c8bc72dcSXiang Li if (!newTy) { 4213644f0bSJakub Kuderski return rewriter.notifyMatchFailure( 43c8bc72dcSXiang Li op, llvm::formatv("failed to legalize type for argument {0})", arg)); 4413644f0bSJakub Kuderski } 45c8bc72dcSXiang Li unsigned argNum = arg.getArgNumber(); 46c8bc72dcSXiang Li Location loc = arg.getLoc(); 47c8bc72dcSXiang Li Value newArg = block.insertArgument(argNum, newTy, loc); 48c8bc72dcSXiang Li Value convertedValue = converter.materializeSourceConversion( 49c8bc72dcSXiang Li builder, op->getLoc(), ty, newArg); 50c8bc72dcSXiang Li if (!convertedValue) { 51c8bc72dcSXiang Li return rewriter.notifyMatchFailure( 52c8bc72dcSXiang Li op, llvm::formatv("failed to cast new argument {0} to type {1})", 53c8bc72dcSXiang Li newArg, ty)); 54c8bc72dcSXiang Li } 55c8bc72dcSXiang Li arg.replaceAllUsesWith(convertedValue); 56c8bc72dcSXiang Li block.eraseArgument(argNum + 1); 5713644f0bSJakub Kuderski } 5813644f0bSJakub Kuderski return success(); 5913644f0bSJakub Kuderski } 6013644f0bSJakub Kuderski 61ace01605SRiver Riddle //===----------------------------------------------------------------------===// 62ace01605SRiver Riddle // Operation conversion 63ace01605SRiver Riddle //===----------------------------------------------------------------------===// 64ace01605SRiver Riddle 65ace01605SRiver Riddle namespace { 665ab6ef75SJakub Kuderski /// Converts cf.br to 
spirv.Branch. 6713644f0bSJakub Kuderski struct BranchOpPattern final : OpConversionPattern<cf::BranchOp> { 6813644f0bSJakub Kuderski using OpConversionPattern::OpConversionPattern; 69ace01605SRiver Riddle 70ace01605SRiver Riddle LogicalResult 71ace01605SRiver Riddle matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor, 72ace01605SRiver Riddle ConversionPatternRewriter &rewriter) const override { 73c8bc72dcSXiang Li if (failed(legalizeBlockArguments(*op.getDest(), op, rewriter, 7413644f0bSJakub Kuderski *getTypeConverter()))) 7513644f0bSJakub Kuderski return failure(); 7613644f0bSJakub Kuderski 77ace01605SRiver Riddle rewriter.replaceOpWithNewOp<spirv::BranchOp>(op, op.getDest(), 78ace01605SRiver Riddle adaptor.getDestOperands()); 79ace01605SRiver Riddle return success(); 80ace01605SRiver Riddle } 81ace01605SRiver Riddle }; 82ace01605SRiver Riddle 835ab6ef75SJakub Kuderski /// Converts cf.cond_br to spirv.BranchConditional. 8413644f0bSJakub Kuderski struct CondBranchOpPattern final : OpConversionPattern<cf::CondBranchOp> { 8513644f0bSJakub Kuderski using OpConversionPattern::OpConversionPattern; 86ace01605SRiver Riddle 87ace01605SRiver Riddle LogicalResult 88ace01605SRiver Riddle matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor, 89ace01605SRiver Riddle ConversionPatternRewriter &rewriter) const override { 90c8bc72dcSXiang Li if (failed(legalizeBlockArguments(*op.getTrueDest(), op, rewriter, 9113644f0bSJakub Kuderski *getTypeConverter()))) 9213644f0bSJakub Kuderski return failure(); 9313644f0bSJakub Kuderski 94c8bc72dcSXiang Li if (failed(legalizeBlockArguments(*op.getFalseDest(), op, rewriter, 9513644f0bSJakub Kuderski *getTypeConverter()))) 9613644f0bSJakub Kuderski return failure(); 9713644f0bSJakub Kuderski 98ace01605SRiver Riddle rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>( 9913644f0bSJakub Kuderski op, adaptor.getCondition(), op.getTrueDest(), 10013644f0bSJakub Kuderski adaptor.getTrueDestOperands(), op.getFalseDest(), 10113644f0bSJakub Kuderski 
adaptor.getFalseDestOperands()); 102ace01605SRiver Riddle return success(); 103ace01605SRiver Riddle } 104ace01605SRiver Riddle }; 105ace01605SRiver Riddle } // namespace 106ace01605SRiver Riddle 107ace01605SRiver Riddle //===----------------------------------------------------------------------===// 108ace01605SRiver Riddle // Pattern population 109ace01605SRiver Riddle //===----------------------------------------------------------------------===// 110ace01605SRiver Riddle 111ace01605SRiver Riddle void mlir::cf::populateControlFlowToSPIRVPatterns( 112*206fad0eSMatthias Springer const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { 113ace01605SRiver Riddle MLIRContext *context = patterns.getContext(); 114ace01605SRiver Riddle 115ace01605SRiver Riddle patterns.add<BranchOpPattern, CondBranchOpPattern>(typeConverter, context); 116ace01605SRiver Riddle } 117