1ace01605SRiver Riddle //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// 2ace01605SRiver Riddle // 3ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ace01605SRiver Riddle // 7ace01605SRiver Riddle //===----------------------------------------------------------------------===// 8ace01605SRiver Riddle // 9ace01605SRiver Riddle // This file implements a pass to convert MLIR standard and builtin dialects 10ace01605SRiver Riddle // into the LLVM IR dialect. 11ace01605SRiver Riddle // 12ace01605SRiver Riddle //===----------------------------------------------------------------------===// 13ace01605SRiver Riddle 14ace01605SRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 1567d0d7acSMichele Scuttari 16c4769ef5SMatthias Springer #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 17ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 18ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/Pattern.h" 193be3883eSBenjamin Maxwell #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" 20ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 21ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 22ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 23ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 24ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h" 25ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h" 2667d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h" 27ace01605SRiver Riddle #include "mlir/Transforms/DialectConversion.h" 28448adfeeSTres Popp #include "llvm/ADT/StringRef.h" 29ace01605SRiver Riddle #include <functional> 30ace01605SRiver Riddle 3167d0d7acSMichele Scuttari namespace mlir { 32cd4ca2d7SMarkus Böck #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS 3367d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 3467d0d7acSMichele Scuttari } // namespace mlir 3567d0d7acSMichele Scuttari 36ace01605SRiver Riddle using namespace mlir; 37ace01605SRiver Riddle 38ace01605SRiver Riddle #define PASS_NAME "convert-cf-to-llvm" 39ace01605SRiver Riddle 40ace01605SRiver Riddle namespace { 4123aa5a74SRiver Riddle /// Lower `cf.assert`. The default lowering calls the `abort` function if the 42ace01605SRiver Riddle /// assertion is violated and has no effect otherwise. The failure message is 43ace01605SRiver Riddle /// ignored by the default lowering but should be propagated by any custom 44ace01605SRiver Riddle /// lowering. 45ace01605SRiver Riddle struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { 46206fad0eSMatthias Springer explicit AssertOpLowering(const LLVMTypeConverter &typeConverter, 47325b58d5SMatthias Springer bool abortOnFailedAssert = true) 48325b58d5SMatthias Springer : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), 49325b58d5SMatthias Springer abortOnFailedAssert(abortOnFailedAssert) {} 50ace01605SRiver Riddle 51ace01605SRiver Riddle LogicalResult 52ace01605SRiver Riddle matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 53ace01605SRiver Riddle ConversionPatternRewriter &rewriter) const override { 54ace01605SRiver Riddle auto loc = op.getLoc(); 55ace01605SRiver Riddle auto module = op->getParentOfType<ModuleOp>(); 56325b58d5SMatthias Springer 57325b58d5SMatthias Springer // Split block at `assert` operation. 58325b58d5SMatthias Springer Block *opBlock = rewriter.getInsertionBlock(); 59325b58d5SMatthias Springer auto opPosition = rewriter.getInsertionPoint(); 60325b58d5SMatthias Springer Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); 61325b58d5SMatthias Springer 62325b58d5SMatthias Springer // Failed block: Generate IR to print the message and call `abort`. 63325b58d5SMatthias Springer Block *failureBlock = rewriter.createBlock(opBlock->getParent()); 64*e84f6b6aSLuohao Wang auto createResult = LLVM::createPrintStrCall( 65*e84f6b6aSLuohao Wang rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(), 66*e84f6b6aSLuohao Wang /*addNewLine=*/false, 673be3883eSBenjamin Maxwell /*runtimeFunctionName=*/"puts"); 68*e84f6b6aSLuohao Wang if (createResult.failed()) 69*e84f6b6aSLuohao Wang return failure(); 70*e84f6b6aSLuohao Wang 71325b58d5SMatthias Springer if (abortOnFailedAssert) { 72325b58d5SMatthias Springer // Insert the `abort` declaration if necessary. 73ace01605SRiver Riddle auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); 74ace01605SRiver Riddle if (!abortFunc) { 75ace01605SRiver Riddle OpBuilder::InsertionGuard guard(rewriter); 76ace01605SRiver Riddle rewriter.setInsertionPointToStart(module.getBody()); 77ace01605SRiver Riddle auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); 78ace01605SRiver Riddle abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), 79ace01605SRiver Riddle "abort", abortFuncTy); 80ace01605SRiver Riddle } 811a36588eSKazu Hirata rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt); 82ace01605SRiver Riddle rewriter.create<LLVM::UnreachableOp>(loc); 83325b58d5SMatthias Springer } else { 84325b58d5SMatthias Springer rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock); 85325b58d5SMatthias Springer } 86ace01605SRiver Riddle 87ace01605SRiver Riddle // Generate assertion test. 88ace01605SRiver Riddle rewriter.setInsertionPointToEnd(opBlock); 89ace01605SRiver Riddle rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 90ace01605SRiver Riddle op, adaptor.getArg(), continuationBlock, failureBlock); 91ace01605SRiver Riddle 92ace01605SRiver Riddle return success(); 93ace01605SRiver Riddle } 94325b58d5SMatthias Springer 95325b58d5SMatthias Springer private: 96325b58d5SMatthias Springer /// If set to `false`, messages are printed but program execution continues. 97325b58d5SMatthias Springer /// This is useful for testing asserts. 98325b58d5SMatthias Springer bool abortOnFailedAssert = true; 99ace01605SRiver Riddle }; 100ace01605SRiver Riddle 101eb6c4197SMatthias Springer /// Helper function for converting branch ops. This function converts the 102eb6c4197SMatthias Springer /// signature of the given block. If the new block signature is different from 103eb6c4197SMatthias Springer /// `expectedTypes`, returns "failure". 104eb6c4197SMatthias Springer static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, 105eb6c4197SMatthias Springer const TypeConverter *converter, 106eb6c4197SMatthias Springer Operation *branchOp, Block *block, 107eb6c4197SMatthias Springer TypeRange expectedTypes) { 108eb6c4197SMatthias Springer assert(converter && "expected non-null type converter"); 109eb6c4197SMatthias Springer assert(!block->isEntryBlock() && "entry blocks have no predecessors"); 110eb6c4197SMatthias Springer 111eb6c4197SMatthias Springer // There is nothing to do if the types already match. 112eb6c4197SMatthias Springer if (block->getArgumentTypes() == expectedTypes) 113eb6c4197SMatthias Springer return block; 114eb6c4197SMatthias Springer 115eb6c4197SMatthias Springer // Compute the new block argument types and convert the block. 116eb6c4197SMatthias Springer std::optional<TypeConverter::SignatureConversion> conversion = 117eb6c4197SMatthias Springer converter->convertBlockSignature(block); 118eb6c4197SMatthias Springer if (!conversion) 119eb6c4197SMatthias Springer return rewriter.notifyMatchFailure(branchOp, 120eb6c4197SMatthias Springer "could not compute block signature"); 121eb6c4197SMatthias Springer if (expectedTypes != conversion->getConvertedTypes()) 122eb6c4197SMatthias Springer return rewriter.notifyMatchFailure( 123eb6c4197SMatthias Springer branchOp, 124eb6c4197SMatthias Springer "mismatch between adaptor operand types and computed block signature"); 125eb6c4197SMatthias Springer return rewriter.applySignatureConversion(block, *conversion, converter); 126448adfeeSTres Popp } 127448adfeeSTres Popp 128eb6c4197SMatthias Springer /// Convert the destination block signature (if necessary) and lower the branch 129eb6c4197SMatthias Springer /// op to llvm.br. 130448adfeeSTres Popp struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { 131448adfeeSTres Popp using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; 132ace01605SRiver Riddle 133ace01605SRiver Riddle LogicalResult 134448adfeeSTres Popp matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, 135ace01605SRiver Riddle ConversionPatternRewriter &rewriter) const override { 136eb6c4197SMatthias Springer FailureOr<Block *> convertedBlock = 137eb6c4197SMatthias Springer getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), 138eb6c4197SMatthias Springer TypeRange(adaptor.getOperands())); 139eb6c4197SMatthias Springer if (failed(convertedBlock)) 140448adfeeSTres Popp return failure(); 141eb6c4197SMatthias Springer Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( 142eb6c4197SMatthias Springer op, adaptor.getOperands(), *convertedBlock); 143eb6c4197SMatthias Springer // TODO: We should not just forward all attributes like that. But there are 144eb6c4197SMatthias Springer // existing Flang tests that depend on this behavior. 145eb6c4197SMatthias Springer newOp->setAttrs(op->getAttrDictionary()); 146ace01605SRiver Riddle return success(); 147ace01605SRiver Riddle } 148ace01605SRiver Riddle }; 149ace01605SRiver Riddle 150eb6c4197SMatthias Springer /// Convert the destination block signatures (if necessary) and lower the 151eb6c4197SMatthias Springer /// branch op to llvm.cond_br. 152448adfeeSTres Popp struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { 153448adfeeSTres Popp using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; 154448adfeeSTres Popp 155448adfeeSTres Popp LogicalResult 156448adfeeSTres Popp matchAndRewrite(cf::CondBranchOp op, 157448adfeeSTres Popp typename cf::CondBranchOp::Adaptor adaptor, 158448adfeeSTres Popp ConversionPatternRewriter &rewriter) const override { 159eb6c4197SMatthias Springer FailureOr<Block *> convertedTrueBlock = 160eb6c4197SMatthias Springer getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), 161eb6c4197SMatthias Springer TypeRange(adaptor.getTrueDestOperands())); 162eb6c4197SMatthias Springer if (failed(convertedTrueBlock)) 163448adfeeSTres Popp return failure(); 164eb6c4197SMatthias Springer FailureOr<Block *> convertedFalseBlock = 165eb6c4197SMatthias Springer getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), 166eb6c4197SMatthias Springer TypeRange(adaptor.getFalseDestOperands())); 167eb6c4197SMatthias Springer if (failed(convertedFalseBlock)) 168448adfeeSTres Popp return failure(); 169eb6c4197SMatthias Springer Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 170eb6c4197SMatthias Springer op, adaptor.getCondition(), *convertedTrueBlock, 171eb6c4197SMatthias Springer adaptor.getTrueDestOperands(), *convertedFalseBlock, 172eb6c4197SMatthias Springer adaptor.getFalseDestOperands()); 173eb6c4197SMatthias Springer // TODO: We should not just forward all attributes like that. But there are 174eb6c4197SMatthias Springer // existing Flang tests that depend on this behavior. 175eb6c4197SMatthias Springer newOp->setAttrs(op->getAttrDictionary()); 176448adfeeSTres Popp return success(); 177448adfeeSTres Popp } 178ace01605SRiver Riddle }; 179448adfeeSTres Popp 180eb6c4197SMatthias Springer /// Convert the destination block signatures (if necessary) and lower the 181eb6c4197SMatthias Springer /// switch op to llvm.switch. 182448adfeeSTres Popp struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { 183448adfeeSTres Popp using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; 184448adfeeSTres Popp 185448adfeeSTres Popp LogicalResult 186448adfeeSTres Popp matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, 187448adfeeSTres Popp ConversionPatternRewriter &rewriter) const override { 188eb6c4197SMatthias Springer // Get or convert default block. 189eb6c4197SMatthias Springer FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( 190eb6c4197SMatthias Springer rewriter, getTypeConverter(), op, op.getDefaultDestination(), 191eb6c4197SMatthias Springer TypeRange(adaptor.getDefaultOperands())); 192eb6c4197SMatthias Springer if (failed(convertedDefaultBlock)) 193448adfeeSTres Popp return failure(); 194448adfeeSTres Popp 195eb6c4197SMatthias Springer // Get or convert all case blocks. 196eb6c4197SMatthias Springer SmallVector<Block *> caseDestinations; 197eb6c4197SMatthias Springer SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands(); 198eb6c4197SMatthias Springer for (auto it : llvm::enumerate(op.getCaseDestinations())) { 199eb6c4197SMatthias Springer Block *b = it.value(); 200eb6c4197SMatthias Springer FailureOr<Block *> convertedBlock = 201eb6c4197SMatthias Springer getConvertedBlock(rewriter, getTypeConverter(), op, b, 202eb6c4197SMatthias Springer TypeRange(caseOperands[it.index()])); 203eb6c4197SMatthias Springer if (failed(convertedBlock)) 204448adfeeSTres Popp return failure(); 205eb6c4197SMatthias Springer caseDestinations.push_back(*convertedBlock); 206448adfeeSTres Popp } 207448adfeeSTres Popp 208448adfeeSTres Popp rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 209eb6c4197SMatthias Springer op, adaptor.getFlag(), *convertedDefaultBlock, 210eb6c4197SMatthias Springer adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(), 211eb6c4197SMatthias Springer caseDestinations, caseOperands); 212448adfeeSTres Popp return success(); 213448adfeeSTres Popp } 214ace01605SRiver Riddle }; 215ace01605SRiver Riddle 216ace01605SRiver Riddle } // namespace 217ace01605SRiver Riddle 218fcb0294bSAlex Zinenko void mlir::cf::populateControlFlowToLLVMConversionPatterns( 219206fad0eSMatthias Springer const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 220ace01605SRiver Riddle // clang-format off 221ace01605SRiver Riddle patterns.add< 222ace01605SRiver Riddle BranchOpLowering, 223ace01605SRiver Riddle CondBranchOpLowering, 224ace01605SRiver Riddle SwitchOpLowering>(converter); 225ace01605SRiver Riddle // clang-format on 226ace01605SRiver Riddle } 227ace01605SRiver Riddle 228325b58d5SMatthias Springer void mlir::cf::populateAssertToLLVMConversionPattern( 229206fad0eSMatthias Springer const LLVMTypeConverter &converter, RewritePatternSet &patterns, 230325b58d5SMatthias Springer bool abortOnFailure) { 231325b58d5SMatthias Springer patterns.add<AssertOpLowering>(converter, abortOnFailure); 232325b58d5SMatthias Springer } 233325b58d5SMatthias Springer 234ace01605SRiver Riddle //===----------------------------------------------------------------------===// 235ace01605SRiver Riddle // Pass Definition 236ace01605SRiver Riddle //===----------------------------------------------------------------------===// 237ace01605SRiver Riddle 238ace01605SRiver Riddle namespace { 239ace01605SRiver Riddle /// A pass converting MLIR operations into the LLVM IR dialect. 240039b969bSMichele Scuttari struct ConvertControlFlowToLLVM 241cd4ca2d7SMarkus Böck : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> { 242cd4ca2d7SMarkus Böck 243cd4ca2d7SMarkus Böck using Base::Base; 244ace01605SRiver Riddle 245ace01605SRiver Riddle /// Run the dialect converter on the module. 246ace01605SRiver Riddle void runOnOperation() override { 247eb6c4197SMatthias Springer MLIRContext *ctx = &getContext(); 248eb6c4197SMatthias Springer LLVMConversionTarget target(*ctx); 249eb6c4197SMatthias Springer // This pass lowers only CF dialect ops, but it also modifies block 250eb6c4197SMatthias Springer // signatures inside other ops. These ops should be treated as legal. They 251eb6c4197SMatthias Springer // are lowered by other passes. 252eb6c4197SMatthias Springer target.markUnknownOpDynamicallyLegal([&](Operation *op) { 253eb6c4197SMatthias Springer return op->getDialect() != 254eb6c4197SMatthias Springer ctx->getLoadedDialect<cf::ControlFlowDialect>(); 255eb6c4197SMatthias Springer }); 256ace01605SRiver Riddle 257eb6c4197SMatthias Springer LowerToLLVMOptions options(ctx); 258ace01605SRiver Riddle if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 259ace01605SRiver Riddle options.overrideIndexBitwidth(indexBitwidth); 260ace01605SRiver Riddle 261eb6c4197SMatthias Springer LLVMTypeConverter converter(ctx, options); 262eb6c4197SMatthias Springer RewritePatternSet patterns(ctx); 263fcb0294bSAlex Zinenko mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 264599c7399SMatthias Springer mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns); 265fcb0294bSAlex Zinenko 266ace01605SRiver Riddle if (failed(applyPartialConversion(getOperation(), target, 267ace01605SRiver Riddle std::move(patterns)))) 268ace01605SRiver Riddle signalPassFailure(); 269ace01605SRiver Riddle } 270ace01605SRiver Riddle }; 271ace01605SRiver Riddle } // namespace 272c4769ef5SMatthias Springer 273c4769ef5SMatthias Springer //===----------------------------------------------------------------------===// 274c4769ef5SMatthias Springer // ConvertToLLVMPatternInterface implementation 275c4769ef5SMatthias Springer //===----------------------------------------------------------------------===// 276c4769ef5SMatthias Springer 277c4769ef5SMatthias Springer namespace { 278c4769ef5SMatthias Springer /// Implement the interface to convert MemRef to LLVM. 279c4769ef5SMatthias Springer struct ControlFlowToLLVMDialectInterface 280c4769ef5SMatthias Springer : public ConvertToLLVMPatternInterface { 281c4769ef5SMatthias Springer using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 282c4769ef5SMatthias Springer void loadDependentDialects(MLIRContext *context) const final { 283c4769ef5SMatthias Springer context->loadDialect<LLVM::LLVMDialect>(); 284c4769ef5SMatthias Springer } 285c4769ef5SMatthias Springer 286c4769ef5SMatthias Springer /// Hook for derived dialect interface to provide conversion patterns 287c4769ef5SMatthias Springer /// and mark dialect legal for the conversion target. 288c4769ef5SMatthias Springer void populateConvertToLLVMConversionPatterns( 289c4769ef5SMatthias Springer ConversionTarget &target, LLVMTypeConverter &typeConverter, 290c4769ef5SMatthias Springer RewritePatternSet &patterns) const final { 291c4769ef5SMatthias Springer mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, 292c4769ef5SMatthias Springer patterns); 293599c7399SMatthias Springer mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); 294c4769ef5SMatthias Springer } 295c4769ef5SMatthias Springer }; 296c4769ef5SMatthias Springer } // namespace 297c4769ef5SMatthias Springer 298c4769ef5SMatthias Springer void mlir::cf::registerConvertControlFlowToLLVMInterface( 299c4769ef5SMatthias Springer DialectRegistry ®istry) { 300c4769ef5SMatthias Springer registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { 301c4769ef5SMatthias Springer dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); 302c4769ef5SMatthias Springer }); 303c4769ef5SMatthias Springer } 304