1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert MLIR standard and builtin dialects 10 // into the LLVM IR dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 15 16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 17 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 18 #include "mlir/Conversion/LLVMCommon/Pattern.h" 19 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" 20 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 21 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 22 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 24 #include "mlir/IR/BuiltinOps.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/Pass/Pass.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 #include "llvm/ADT/StringRef.h" 29 #include <functional> 30 31 namespace mlir { 32 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS 33 #include "mlir/Conversion/Passes.h.inc" 34 } // namespace mlir 35 36 using namespace mlir; 37 38 #define PASS_NAME "convert-cf-to-llvm" 39 40 namespace { 41 /// Lower `cf.assert`. The default lowering calls the `abort` function if the 42 /// assertion is violated and has no effect otherwise. The failure message is 43 /// ignored by the default lowering but should be propagated by any custom 44 /// lowering. 45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { 46 explicit AssertOpLowering(LLVMTypeConverter &typeConverter, 47 bool abortOnFailedAssert = true) 48 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), 49 abortOnFailedAssert(abortOnFailedAssert) {} 50 51 LogicalResult 52 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 53 ConversionPatternRewriter &rewriter) const override { 54 auto loc = op.getLoc(); 55 auto module = op->getParentOfType<ModuleOp>(); 56 57 // Split block at `assert` operation. 58 Block *opBlock = rewriter.getInsertionBlock(); 59 auto opPosition = rewriter.getInsertionPoint(); 60 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); 61 62 // Failed block: Generate IR to print the message and call `abort`. 63 Block *failureBlock = rewriter.createBlock(opBlock->getParent()); 64 LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(), 65 *getTypeConverter(), /*addNewLine=*/false, 66 /*runtimeFunctionName=*/"puts"); 67 if (abortOnFailedAssert) { 68 // Insert the `abort` declaration if necessary. 69 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); 70 if (!abortFunc) { 71 OpBuilder::InsertionGuard guard(rewriter); 72 rewriter.setInsertionPointToStart(module.getBody()); 73 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); 74 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), 75 "abort", abortFuncTy); 76 } 77 rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt); 78 rewriter.create<LLVM::UnreachableOp>(loc); 79 } else { 80 rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock); 81 } 82 83 // Generate assertion test. 84 rewriter.setInsertionPointToEnd(opBlock); 85 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 86 op, adaptor.getArg(), continuationBlock, failureBlock); 87 88 return success(); 89 } 90 91 private: 92 /// If set to `false`, messages are printed but program execution continues. 93 /// This is useful for testing asserts. 94 bool abortOnFailedAssert = true; 95 }; 96 97 /// The cf->LLVM lowerings for branching ops require that the blocks they jump 98 /// to first have updated types which should be handled by a pattern operating 99 /// on the parent op. 100 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter, 101 ValueRange operands, 102 ValueRange blockArgs, Location loc, 103 llvm::StringRef messagePrefix) { 104 for (const auto &idxAndTypes : 105 llvm::enumerate(llvm::zip(blockArgs, operands))) { 106 int64_t i = idxAndTypes.index(); 107 Value argValue = 108 rewriter.getRemappedValue(std::get<0>(idxAndTypes.value())); 109 Type operandType = std::get<1>(idxAndTypes.value()).getType(); 110 // In the case of an invalid jump, the block argument will have been 111 // remapped to an UnrealizedConversionCast. In the case of a valid jump, 112 // there might still be a no-op conversion cast with both types being equal. 113 // Consider both of these details to see if the jump would be invalid. 114 if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>( 115 argValue.getDefiningOp())) { 116 if (op.getOperandTypes().front() != operandType) { 117 return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) { 118 diag << messagePrefix; 119 diag << "mismatched types from operand # " << i << " "; 120 diag << operandType; 121 diag << " not compatible with destination block argument type "; 122 diag << op.getOperandTypes().front(); 123 diag << " which should be converted with the parent op."; 124 }); 125 } 126 } 127 } 128 return success(); 129 } 130 131 /// Ensure that all block types were updated and then create an LLVM::BrOp 132 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { 133 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; 134 135 LogicalResult 136 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, 137 ConversionPatternRewriter &rewriter) const override { 138 if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(), 139 op.getSuccessor()->getArguments(), 140 op.getLoc(), 141 /*messagePrefix=*/""))) 142 return failure(); 143 144 rewriter.replaceOpWithNewOp<LLVM::BrOp>( 145 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 146 return success(); 147 } 148 }; 149 150 /// Ensure that all block types were updated and then create an LLVM::CondBrOp 151 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { 152 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; 153 154 LogicalResult 155 matchAndRewrite(cf::CondBranchOp op, 156 typename cf::CondBranchOp::Adaptor adaptor, 157 ConversionPatternRewriter &rewriter) const override { 158 if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(), 159 op.getFalseDest()->getArguments(), 160 op.getLoc(), "in false case branch "))) 161 return failure(); 162 if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(), 163 op.getTrueDest()->getArguments(), 164 op.getLoc(), "in true case branch "))) 165 return failure(); 166 167 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 168 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 169 return success(); 170 } 171 }; 172 173 /// Ensure that all block types were updated and then create an LLVM::SwitchOp 174 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { 175 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; 176 177 LogicalResult 178 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, 179 ConversionPatternRewriter &rewriter) const override { 180 if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(), 181 op.getDefaultDestination()->getArguments(), 182 op.getLoc(), "in switch default case "))) 183 return failure(); 184 185 for (const auto &i : llvm::enumerate( 186 llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) { 187 if (failed(verifyMatchingValues( 188 rewriter, std::get<0>(i.value()), 189 std::get<1>(i.value())->getArguments(), op.getLoc(), 190 "in switch case " + std::to_string(i.index()) + " "))) { 191 return failure(); 192 } 193 } 194 195 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 196 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 197 return success(); 198 } 199 }; 200 201 } // namespace 202 203 void mlir::cf::populateControlFlowToLLVMConversionPatterns( 204 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 205 // clang-format off 206 patterns.add< 207 AssertOpLowering, 208 BranchOpLowering, 209 CondBranchOpLowering, 210 SwitchOpLowering>(converter); 211 // clang-format on 212 } 213 214 void mlir::cf::populateAssertToLLVMConversionPattern( 215 LLVMTypeConverter &converter, RewritePatternSet &patterns, 216 bool abortOnFailure) { 217 patterns.add<AssertOpLowering>(converter, abortOnFailure); 218 } 219 220 //===----------------------------------------------------------------------===// 221 // Pass Definition 222 //===----------------------------------------------------------------------===// 223 224 namespace { 225 /// A pass converting MLIR operations into the LLVM IR dialect. 226 struct ConvertControlFlowToLLVM 227 : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> { 228 229 using Base::Base; 230 231 /// Run the dialect converter on the module. 232 void runOnOperation() override { 233 LLVMConversionTarget target(getContext()); 234 RewritePatternSet patterns(&getContext()); 235 236 LowerToLLVMOptions options(&getContext()); 237 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 238 options.overrideIndexBitwidth(indexBitwidth); 239 240 LLVMTypeConverter converter(&getContext(), options); 241 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 242 243 if (failed(applyPartialConversion(getOperation(), target, 244 std::move(patterns)))) 245 signalPassFailure(); 246 } 247 }; 248 } // namespace 249 250 //===----------------------------------------------------------------------===// 251 // ConvertToLLVMPatternInterface implementation 252 //===----------------------------------------------------------------------===// 253 254 namespace { 255 /// Implement the interface to convert MemRef to LLVM. 256 struct ControlFlowToLLVMDialectInterface 257 : public ConvertToLLVMPatternInterface { 258 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 259 void loadDependentDialects(MLIRContext *context) const final { 260 context->loadDialect<LLVM::LLVMDialect>(); 261 } 262 263 /// Hook for derived dialect interface to provide conversion patterns 264 /// and mark dialect legal for the conversion target. 265 void populateConvertToLLVMConversionPatterns( 266 ConversionTarget &target, LLVMTypeConverter &typeConverter, 267 RewritePatternSet &patterns) const final { 268 mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, 269 patterns); 270 } 271 }; 272 } // namespace 273 274 void mlir::cf::registerConvertControlFlowToLLVMInterface( 275 DialectRegistry ®istry) { 276 registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { 277 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); 278 }); 279 } 280