1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert MLIR standard and builtin dialects 10 // into the LLVM IR dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 15 #include "../PassDetail.h" 16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 17 #include "mlir/Conversion/LLVMCommon/Pattern.h" 18 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include <functional> 26 27 using namespace mlir; 28 29 #define PASS_NAME "convert-cf-to-llvm" 30 31 namespace { 32 /// Lower `std.assert`. The default lowering calls the `abort` function if the 33 /// assertion is violated and has no effect otherwise. The failure message is 34 /// ignored by the default lowering but should be propagated by any custom 35 /// lowering. 36 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { 37 using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern; 38 39 LogicalResult 40 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 41 ConversionPatternRewriter &rewriter) const override { 42 auto loc = op.getLoc(); 43 44 // Insert the `abort` declaration if necessary. 45 auto module = op->getParentOfType<ModuleOp>(); 46 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); 47 if (!abortFunc) { 48 OpBuilder::InsertionGuard guard(rewriter); 49 rewriter.setInsertionPointToStart(module.getBody()); 50 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); 51 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), 52 "abort", abortFuncTy); 53 } 54 55 // Split block at `assert` operation. 56 Block *opBlock = rewriter.getInsertionBlock(); 57 auto opPosition = rewriter.getInsertionPoint(); 58 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); 59 60 // Generate IR to call `abort`. 61 Block *failureBlock = rewriter.createBlock(opBlock->getParent()); 62 rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None); 63 rewriter.create<LLVM::UnreachableOp>(loc); 64 65 // Generate assertion test. 66 rewriter.setInsertionPointToEnd(opBlock); 67 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 68 op, adaptor.getArg(), continuationBlock, failureBlock); 69 70 return success(); 71 } 72 }; 73 74 // Base class for LLVM IR lowering terminator operations with successors. 75 template <typename SourceOp, typename TargetOp> 76 struct OneToOneLLVMTerminatorLowering 77 : public ConvertOpToLLVMPattern<SourceOp> { 78 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 79 using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>; 80 81 LogicalResult 82 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 83 ConversionPatternRewriter &rewriter) const override { 84 rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(), 85 op->getSuccessors(), op->getAttrs()); 86 return success(); 87 } 88 }; 89 90 // FIXME: this should be tablegen'ed as well. 91 struct BranchOpLowering 92 : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> { 93 using Base::Base; 94 }; 95 struct CondBranchOpLowering 96 : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> { 97 using Base::Base; 98 }; 99 struct SwitchOpLowering 100 : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> { 101 using Base::Base; 102 }; 103 104 } // namespace 105 106 void mlir::cf::populateControlFlowToLLVMConversionPatterns( 107 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 108 // clang-format off 109 patterns.add< 110 AssertOpLowering, 111 BranchOpLowering, 112 CondBranchOpLowering, 113 SwitchOpLowering>(converter); 114 // clang-format on 115 } 116 117 //===----------------------------------------------------------------------===// 118 // Pass Definition 119 //===----------------------------------------------------------------------===// 120 121 namespace { 122 /// A pass converting MLIR operations into the LLVM IR dialect. 123 struct ConvertControlFlowToLLVM 124 : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> { 125 ConvertControlFlowToLLVM() = default; 126 127 /// Run the dialect converter on the module. 128 void runOnOperation() override { 129 LLVMConversionTarget target(getContext()); 130 RewritePatternSet patterns(&getContext()); 131 132 LowerToLLVMOptions options(&getContext()); 133 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 134 options.overrideIndexBitwidth(indexBitwidth); 135 136 LLVMTypeConverter converter(&getContext(), options); 137 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 138 139 if (failed(applyPartialConversion(getOperation(), target, 140 std::move(patterns)))) 141 signalPassFailure(); 142 } 143 }; 144 } // namespace 145 146 std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() { 147 return std::make_unique<ConvertControlFlowToLLVM>(); 148 } 149