1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert MLIR standard and builtin dialects 10 // into the LLVM IR dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 15 16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 17 #include "mlir/Conversion/LLVMCommon/Pattern.h" 18 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "llvm/ADT/StringRef.h" 27 #include <functional> 28 29 namespace mlir { 30 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVM 31 #include "mlir/Conversion/Passes.h.inc" 32 } // namespace mlir 33 34 using namespace mlir; 35 36 #define PASS_NAME "convert-cf-to-llvm" 37 38 namespace { 39 /// Lower `cf.assert`. The default lowering calls the `abort` function if the 40 /// assertion is violated and has no effect otherwise. The failure message is 41 /// ignored by the default lowering but should be propagated by any custom 42 /// lowering. 43 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { 44 using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern; 45 46 LogicalResult 47 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 48 ConversionPatternRewriter &rewriter) const override { 49 auto loc = op.getLoc(); 50 51 // Insert the `abort` declaration if necessary. 52 auto module = op->getParentOfType<ModuleOp>(); 53 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); 54 if (!abortFunc) { 55 OpBuilder::InsertionGuard guard(rewriter); 56 rewriter.setInsertionPointToStart(module.getBody()); 57 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); 58 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), 59 "abort", abortFuncTy); 60 } 61 62 // Split block at `assert` operation. 63 Block *opBlock = rewriter.getInsertionBlock(); 64 auto opPosition = rewriter.getInsertionPoint(); 65 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); 66 67 // Generate IR to call `abort`. 68 Block *failureBlock = rewriter.createBlock(opBlock->getParent()); 69 rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt); 70 rewriter.create<LLVM::UnreachableOp>(loc); 71 72 // Generate assertion test. 73 rewriter.setInsertionPointToEnd(opBlock); 74 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 75 op, adaptor.getArg(), continuationBlock, failureBlock); 76 77 return success(); 78 } 79 }; 80 81 /// The cf->LLVM lowerings for branching ops require that the blocks they jump 82 /// to first have updated types which should be handled by a pattern operating 83 /// on the parent op. 84 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter, 85 ValueRange operands, 86 ValueRange blockArgs, Location loc, 87 llvm::StringRef messagePrefix) { 88 for (const auto &idxAndTypes : 89 llvm::enumerate(llvm::zip(blockArgs, operands))) { 90 int64_t i = idxAndTypes.index(); 91 Value argValue = 92 rewriter.getRemappedValue(std::get<0>(idxAndTypes.value())); 93 Type operandType = std::get<1>(idxAndTypes.value()).getType(); 94 // In the case of an invalid jump, the block argument will have been 95 // remapped to an UnrealizedConversionCast. In the case of a valid jump, 96 // there might still be a no-op conversion cast with both types being equal. 97 // Consider both of these details to see if the jump would be invalid. 98 if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>( 99 argValue.getDefiningOp())) { 100 if (op.getOperandTypes().front() != operandType) { 101 return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) { 102 diag << messagePrefix; 103 diag << "mismatched types from operand # " << i << " "; 104 diag << operandType; 105 diag << " not compatible with destination block argument type "; 106 diag << op.getOperandTypes().front(); 107 diag << " which should be converted with the parent op."; 108 }); 109 } 110 } 111 } 112 return success(); 113 } 114 115 /// Ensure that all block types were updated and then create an LLVM::BrOp 116 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { 117 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; 118 119 LogicalResult 120 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, 121 ConversionPatternRewriter &rewriter) const override { 122 if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(), 123 op.getSuccessor()->getArguments(), 124 op.getLoc(), 125 /*messagePrefix=*/""))) 126 return failure(); 127 128 rewriter.replaceOpWithNewOp<LLVM::BrOp>( 129 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 130 return success(); 131 } 132 }; 133 134 /// Ensure that all block types were updated and then create an LLVM::CondBrOp 135 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { 136 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; 137 138 LogicalResult 139 matchAndRewrite(cf::CondBranchOp op, 140 typename cf::CondBranchOp::Adaptor adaptor, 141 ConversionPatternRewriter &rewriter) const override { 142 if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(), 143 op.getFalseDest()->getArguments(), 144 op.getLoc(), "in false case branch "))) 145 return failure(); 146 if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(), 147 op.getTrueDest()->getArguments(), 148 op.getLoc(), "in true case branch "))) 149 return failure(); 150 151 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 152 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 153 return success(); 154 } 155 }; 156 157 /// Ensure that all block types were updated and then create an LLVM::SwitchOp 158 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { 159 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; 160 161 LogicalResult 162 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, 163 ConversionPatternRewriter &rewriter) const override { 164 if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(), 165 op.getDefaultDestination()->getArguments(), 166 op.getLoc(), "in switch default case "))) 167 return failure(); 168 169 for (const auto &i : llvm::enumerate( 170 llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) { 171 if (failed(verifyMatchingValues( 172 rewriter, std::get<0>(i.value()), 173 std::get<1>(i.value())->getArguments(), op.getLoc(), 174 "in switch case " + std::to_string(i.index()) + " "))) { 175 return failure(); 176 } 177 } 178 179 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 180 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 181 return success(); 182 } 183 }; 184 185 } // namespace 186 187 void mlir::cf::populateControlFlowToLLVMConversionPatterns( 188 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 189 // clang-format off 190 patterns.add< 191 AssertOpLowering, 192 BranchOpLowering, 193 CondBranchOpLowering, 194 SwitchOpLowering>(converter); 195 // clang-format on 196 } 197 198 //===----------------------------------------------------------------------===// 199 // Pass Definition 200 //===----------------------------------------------------------------------===// 201 202 namespace { 203 /// A pass converting MLIR operations into the LLVM IR dialect. 204 struct ConvertControlFlowToLLVM 205 : public impl::ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> { 206 ConvertControlFlowToLLVM() = default; 207 208 /// Run the dialect converter on the module. 209 void runOnOperation() override { 210 LLVMConversionTarget target(getContext()); 211 RewritePatternSet patterns(&getContext()); 212 213 LowerToLLVMOptions options(&getContext()); 214 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 215 options.overrideIndexBitwidth(indexBitwidth); 216 217 LLVMTypeConverter converter(&getContext(), options); 218 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 219 220 if (failed(applyPartialConversion(getOperation(), target, 221 std::move(patterns)))) 222 signalPassFailure(); 223 } 224 }; 225 } // namespace 226 227 std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() { 228 return std::make_unique<ConvertControlFlowToLLVM>(); 229 } 230