//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements conversion patterns and a pass that lower ControlFlow
// (`cf`) dialect operations (`cf.br`, `cf.cond_br`, `cf.switch`, `cf.assert`)
// into the LLVM IR dialect. It also registers a ConvertToLLVM dialect
// interface so the generic `convert-to-llvm` infrastructure can pick up these
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringRef.h"
#include <functional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

#define PASS_NAME "convert-cf-to-llvm"

namespace {
/// Lower `cf.assert`. The assertion condition becomes an `llvm.cond_br`: if it
/// holds, control continues after the assert; otherwise control enters a
/// failure block that prints the assertion message (through a `puts` call).
/// By default the failure block then calls `abort` and ends in
/// `llvm.unreachable`. When the pattern is constructed with
/// `abortOnFailedAssert = false`, the failure block instead branches back to
/// the continuation, so execution proceeds after printing the message.
45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { 46 explicit AssertOpLowering(const LLVMTypeConverter &typeConverter, 47 bool abortOnFailedAssert = true) 48 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), 49 abortOnFailedAssert(abortOnFailedAssert) {} 50 51 LogicalResult 52 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 53 ConversionPatternRewriter &rewriter) const override { 54 auto loc = op.getLoc(); 55 auto module = op->getParentOfType<ModuleOp>(); 56 57 // Split block at `assert` operation. 58 Block *opBlock = rewriter.getInsertionBlock(); 59 auto opPosition = rewriter.getInsertionPoint(); 60 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); 61 62 // Failed block: Generate IR to print the message and call `abort`. 63 Block *failureBlock = rewriter.createBlock(opBlock->getParent()); 64 LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(), 65 *getTypeConverter(), /*addNewLine=*/false, 66 /*runtimeFunctionName=*/"puts"); 67 if (abortOnFailedAssert) { 68 // Insert the `abort` declaration if necessary. 69 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); 70 if (!abortFunc) { 71 OpBuilder::InsertionGuard guard(rewriter); 72 rewriter.setInsertionPointToStart(module.getBody()); 73 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); 74 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), 75 "abort", abortFuncTy); 76 } 77 rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt); 78 rewriter.create<LLVM::UnreachableOp>(loc); 79 } else { 80 rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock); 81 } 82 83 // Generate assertion test. 84 rewriter.setInsertionPointToEnd(opBlock); 85 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 86 op, adaptor.getArg(), continuationBlock, failureBlock); 87 88 return success(); 89 } 90 91 private: 92 /// If set to `false`, messages are printed but program execution continues. 
93 /// This is useful for testing asserts. 94 bool abortOnFailedAssert = true; 95 }; 96 97 /// Helper function for converting branch ops. This function converts the 98 /// signature of the given block. If the new block signature is different from 99 /// `expectedTypes`, returns "failure". 100 static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, 101 const TypeConverter *converter, 102 Operation *branchOp, Block *block, 103 TypeRange expectedTypes) { 104 assert(converter && "expected non-null type converter"); 105 assert(!block->isEntryBlock() && "entry blocks have no predecessors"); 106 107 // There is nothing to do if the types already match. 108 if (block->getArgumentTypes() == expectedTypes) 109 return block; 110 111 // Compute the new block argument types and convert the block. 112 std::optional<TypeConverter::SignatureConversion> conversion = 113 converter->convertBlockSignature(block); 114 if (!conversion) 115 return rewriter.notifyMatchFailure(branchOp, 116 "could not compute block signature"); 117 if (expectedTypes != conversion->getConvertedTypes()) 118 return rewriter.notifyMatchFailure( 119 branchOp, 120 "mismatch between adaptor operand types and computed block signature"); 121 return rewriter.applySignatureConversion(block, *conversion, converter); 122 } 123 124 /// Convert the destination block signature (if necessary) and lower the branch 125 /// op to llvm.br. 
126 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { 127 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; 128 129 LogicalResult 130 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, 131 ConversionPatternRewriter &rewriter) const override { 132 FailureOr<Block *> convertedBlock = 133 getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), 134 TypeRange(adaptor.getOperands())); 135 if (failed(convertedBlock)) 136 return failure(); 137 Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( 138 op, adaptor.getOperands(), *convertedBlock); 139 // TODO: We should not just forward all attributes like that. But there are 140 // existing Flang tests that depend on this behavior. 141 newOp->setAttrs(op->getAttrDictionary()); 142 return success(); 143 } 144 }; 145 146 /// Convert the destination block signatures (if necessary) and lower the 147 /// branch op to llvm.cond_br. 148 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { 149 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; 150 151 LogicalResult 152 matchAndRewrite(cf::CondBranchOp op, 153 typename cf::CondBranchOp::Adaptor adaptor, 154 ConversionPatternRewriter &rewriter) const override { 155 FailureOr<Block *> convertedTrueBlock = 156 getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), 157 TypeRange(adaptor.getTrueDestOperands())); 158 if (failed(convertedTrueBlock)) 159 return failure(); 160 FailureOr<Block *> convertedFalseBlock = 161 getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), 162 TypeRange(adaptor.getFalseDestOperands())); 163 if (failed(convertedFalseBlock)) 164 return failure(); 165 Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 166 op, adaptor.getCondition(), *convertedTrueBlock, 167 adaptor.getTrueDestOperands(), *convertedFalseBlock, 168 adaptor.getFalseDestOperands()); 169 // TODO: We should not just forward all 
attributes like that. But there are 170 // existing Flang tests that depend on this behavior. 171 newOp->setAttrs(op->getAttrDictionary()); 172 return success(); 173 } 174 }; 175 176 /// Convert the destination block signatures (if necessary) and lower the 177 /// switch op to llvm.switch. 178 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { 179 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; 180 181 LogicalResult 182 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, 183 ConversionPatternRewriter &rewriter) const override { 184 // Get or convert default block. 185 FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( 186 rewriter, getTypeConverter(), op, op.getDefaultDestination(), 187 TypeRange(adaptor.getDefaultOperands())); 188 if (failed(convertedDefaultBlock)) 189 return failure(); 190 191 // Get or convert all case blocks. 192 SmallVector<Block *> caseDestinations; 193 SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands(); 194 for (auto it : llvm::enumerate(op.getCaseDestinations())) { 195 Block *b = it.value(); 196 FailureOr<Block *> convertedBlock = 197 getConvertedBlock(rewriter, getTypeConverter(), op, b, 198 TypeRange(caseOperands[it.index()])); 199 if (failed(convertedBlock)) 200 return failure(); 201 caseDestinations.push_back(*convertedBlock); 202 } 203 204 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 205 op, adaptor.getFlag(), *convertedDefaultBlock, 206 adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(), 207 caseDestinations, caseOperands); 208 return success(); 209 } 210 }; 211 212 } // namespace 213 214 void mlir::cf::populateControlFlowToLLVMConversionPatterns( 215 const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 216 // clang-format off 217 patterns.add< 218 BranchOpLowering, 219 CondBranchOpLowering, 220 SwitchOpLowering>(converter); 221 // clang-format on 222 } 223 224 void mlir::cf::populateAssertToLLVMConversionPattern( 225 const 
LLVMTypeConverter &converter, RewritePatternSet &patterns, 226 bool abortOnFailure) { 227 patterns.add<AssertOpLowering>(converter, abortOnFailure); 228 } 229 230 //===----------------------------------------------------------------------===// 231 // Pass Definition 232 //===----------------------------------------------------------------------===// 233 234 namespace { 235 /// A pass converting MLIR operations into the LLVM IR dialect. 236 struct ConvertControlFlowToLLVM 237 : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> { 238 239 using Base::Base; 240 241 /// Run the dialect converter on the module. 242 void runOnOperation() override { 243 MLIRContext *ctx = &getContext(); 244 LLVMConversionTarget target(*ctx); 245 // This pass lowers only CF dialect ops, but it also modifies block 246 // signatures inside other ops. These ops should be treated as legal. They 247 // are lowered by other passes. 248 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 249 return op->getDialect() != 250 ctx->getLoadedDialect<cf::ControlFlowDialect>(); 251 }); 252 253 LowerToLLVMOptions options(ctx); 254 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 255 options.overrideIndexBitwidth(indexBitwidth); 256 257 LLVMTypeConverter converter(ctx, options); 258 RewritePatternSet patterns(ctx); 259 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 260 mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns); 261 262 if (failed(applyPartialConversion(getOperation(), target, 263 std::move(patterns)))) 264 signalPassFailure(); 265 } 266 }; 267 } // namespace 268 269 //===----------------------------------------------------------------------===// 270 // ConvertToLLVMPatternInterface implementation 271 //===----------------------------------------------------------------------===// 272 273 namespace { 274 /// Implement the interface to convert MemRef to LLVM. 
275 struct ControlFlowToLLVMDialectInterface 276 : public ConvertToLLVMPatternInterface { 277 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 278 void loadDependentDialects(MLIRContext *context) const final { 279 context->loadDialect<LLVM::LLVMDialect>(); 280 } 281 282 /// Hook for derived dialect interface to provide conversion patterns 283 /// and mark dialect legal for the conversion target. 284 void populateConvertToLLVMConversionPatterns( 285 ConversionTarget &target, LLVMTypeConverter &typeConverter, 286 RewritePatternSet &patterns) const final { 287 mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, 288 patterns); 289 mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); 290 } 291 }; 292 } // namespace 293 294 void mlir::cf::registerConvertControlFlowToLLVMInterface( 295 DialectRegistry ®istry) { 296 registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { 297 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); 298 }); 299 } 300