//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering of ControlFlow (`cf`) dialect operations
// (`cf.br`, `cf.cond_br`, `cf.switch` and `cf.assert`) into the LLVM IR
// dialect. It provides the pattern-population entry points, a standalone
// conversion pass, and the ConvertToLLVMPatternInterface hook for the `cf`
// dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringRef.h"
#include <functional>

// Pull in the tablegen-generated pass definition; this provides the
// `impl::ConvertControlFlowToLLVMPassBase` class the pass below derives from.
namespace mlir {
#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

// NOTE(review): not referenced anywhere in the visible part of this file;
// confirm it is still needed before relying on it.
#define PASS_NAME "convert-cf-to-llvm"

namespace {
/// Lower `cf.assert`. The enclosing block is split at the assertion and the
/// op becomes an `llvm.cond_br` on the assertion condition. On failure, the
/// assertion message is printed through `puts`; by default `abort` is then
/// called. When the pattern is built with `abortOnFailedAssert = false`,
/// execution instead branches back to the code following the assertion.
/// Custom lowerings should likewise propagate the failure message.
45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { 46 explicit AssertOpLowering(const LLVMTypeConverter &typeConverter, 47 bool abortOnFailedAssert = true) 48 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), 49 abortOnFailedAssert(abortOnFailedAssert) {} 50 51 LogicalResult 52 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 53 ConversionPatternRewriter &rewriter) const override { 54 auto loc = op.getLoc(); 55 auto module = op->getParentOfType<ModuleOp>(); 56 57 // Split block at `assert` operation. 58 Block *opBlock = rewriter.getInsertionBlock(); 59 auto opPosition = rewriter.getInsertionPoint(); 60 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); 61 62 // Failed block: Generate IR to print the message and call `abort`. 63 Block *failureBlock = rewriter.createBlock(opBlock->getParent()); 64 auto createResult = LLVM::createPrintStrCall( 65 rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(), 66 /*addNewLine=*/false, 67 /*runtimeFunctionName=*/"puts"); 68 if (createResult.failed()) 69 return failure(); 70 71 if (abortOnFailedAssert) { 72 // Insert the `abort` declaration if necessary. 73 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); 74 if (!abortFunc) { 75 OpBuilder::InsertionGuard guard(rewriter); 76 rewriter.setInsertionPointToStart(module.getBody()); 77 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); 78 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), 79 "abort", abortFuncTy); 80 } 81 rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt); 82 rewriter.create<LLVM::UnreachableOp>(loc); 83 } else { 84 rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock); 85 } 86 87 // Generate assertion test. 
88 rewriter.setInsertionPointToEnd(opBlock); 89 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 90 op, adaptor.getArg(), continuationBlock, failureBlock); 91 92 return success(); 93 } 94 95 private: 96 /// If set to `false`, messages are printed but program execution continues. 97 /// This is useful for testing asserts. 98 bool abortOnFailedAssert = true; 99 }; 100 101 /// Helper function for converting branch ops. This function converts the 102 /// signature of the given block. If the new block signature is different from 103 /// `expectedTypes`, returns "failure". 104 static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, 105 const TypeConverter *converter, 106 Operation *branchOp, Block *block, 107 TypeRange expectedTypes) { 108 assert(converter && "expected non-null type converter"); 109 assert(!block->isEntryBlock() && "entry blocks have no predecessors"); 110 111 // There is nothing to do if the types already match. 112 if (block->getArgumentTypes() == expectedTypes) 113 return block; 114 115 // Compute the new block argument types and convert the block. 116 std::optional<TypeConverter::SignatureConversion> conversion = 117 converter->convertBlockSignature(block); 118 if (!conversion) 119 return rewriter.notifyMatchFailure(branchOp, 120 "could not compute block signature"); 121 if (expectedTypes != conversion->getConvertedTypes()) 122 return rewriter.notifyMatchFailure( 123 branchOp, 124 "mismatch between adaptor operand types and computed block signature"); 125 return rewriter.applySignatureConversion(block, *conversion, converter); 126 } 127 128 /// Convert the destination block signature (if necessary) and lower the branch 129 /// op to llvm.br. 
130 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { 131 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; 132 133 LogicalResult 134 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, 135 ConversionPatternRewriter &rewriter) const override { 136 FailureOr<Block *> convertedBlock = 137 getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), 138 TypeRange(adaptor.getOperands())); 139 if (failed(convertedBlock)) 140 return failure(); 141 Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( 142 op, adaptor.getOperands(), *convertedBlock); 143 // TODO: We should not just forward all attributes like that. But there are 144 // existing Flang tests that depend on this behavior. 145 newOp->setAttrs(op->getAttrDictionary()); 146 return success(); 147 } 148 }; 149 150 /// Convert the destination block signatures (if necessary) and lower the 151 /// branch op to llvm.cond_br. 152 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { 153 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; 154 155 LogicalResult 156 matchAndRewrite(cf::CondBranchOp op, 157 typename cf::CondBranchOp::Adaptor adaptor, 158 ConversionPatternRewriter &rewriter) const override { 159 FailureOr<Block *> convertedTrueBlock = 160 getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), 161 TypeRange(adaptor.getTrueDestOperands())); 162 if (failed(convertedTrueBlock)) 163 return failure(); 164 FailureOr<Block *> convertedFalseBlock = 165 getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), 166 TypeRange(adaptor.getFalseDestOperands())); 167 if (failed(convertedFalseBlock)) 168 return failure(); 169 Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 170 op, adaptor.getCondition(), *convertedTrueBlock, 171 adaptor.getTrueDestOperands(), *convertedFalseBlock, 172 adaptor.getFalseDestOperands()); 173 // TODO: We should not just forward all 
attributes like that. But there are 174 // existing Flang tests that depend on this behavior. 175 newOp->setAttrs(op->getAttrDictionary()); 176 return success(); 177 } 178 }; 179 180 /// Convert the destination block signatures (if necessary) and lower the 181 /// switch op to llvm.switch. 182 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { 183 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; 184 185 LogicalResult 186 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, 187 ConversionPatternRewriter &rewriter) const override { 188 // Get or convert default block. 189 FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( 190 rewriter, getTypeConverter(), op, op.getDefaultDestination(), 191 TypeRange(adaptor.getDefaultOperands())); 192 if (failed(convertedDefaultBlock)) 193 return failure(); 194 195 // Get or convert all case blocks. 196 SmallVector<Block *> caseDestinations; 197 SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands(); 198 for (auto it : llvm::enumerate(op.getCaseDestinations())) { 199 Block *b = it.value(); 200 FailureOr<Block *> convertedBlock = 201 getConvertedBlock(rewriter, getTypeConverter(), op, b, 202 TypeRange(caseOperands[it.index()])); 203 if (failed(convertedBlock)) 204 return failure(); 205 caseDestinations.push_back(*convertedBlock); 206 } 207 208 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 209 op, adaptor.getFlag(), *convertedDefaultBlock, 210 adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(), 211 caseDestinations, caseOperands); 212 return success(); 213 } 214 }; 215 216 } // namespace 217 218 void mlir::cf::populateControlFlowToLLVMConversionPatterns( 219 const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 220 // clang-format off 221 patterns.add< 222 BranchOpLowering, 223 CondBranchOpLowering, 224 SwitchOpLowering>(converter); 225 // clang-format on 226 } 227 228 void mlir::cf::populateAssertToLLVMConversionPattern( 229 const 
LLVMTypeConverter &converter, RewritePatternSet &patterns, 230 bool abortOnFailure) { 231 patterns.add<AssertOpLowering>(converter, abortOnFailure); 232 } 233 234 //===----------------------------------------------------------------------===// 235 // Pass Definition 236 //===----------------------------------------------------------------------===// 237 238 namespace { 239 /// A pass converting MLIR operations into the LLVM IR dialect. 240 struct ConvertControlFlowToLLVM 241 : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> { 242 243 using Base::Base; 244 245 /// Run the dialect converter on the module. 246 void runOnOperation() override { 247 MLIRContext *ctx = &getContext(); 248 LLVMConversionTarget target(*ctx); 249 // This pass lowers only CF dialect ops, but it also modifies block 250 // signatures inside other ops. These ops should be treated as legal. They 251 // are lowered by other passes. 252 target.markUnknownOpDynamicallyLegal([&](Operation *op) { 253 return op->getDialect() != 254 ctx->getLoadedDialect<cf::ControlFlowDialect>(); 255 }); 256 257 LowerToLLVMOptions options(ctx); 258 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 259 options.overrideIndexBitwidth(indexBitwidth); 260 261 LLVMTypeConverter converter(ctx, options); 262 RewritePatternSet patterns(ctx); 263 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 264 mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns); 265 266 if (failed(applyPartialConversion(getOperation(), target, 267 std::move(patterns)))) 268 signalPassFailure(); 269 } 270 }; 271 } // namespace 272 273 //===----------------------------------------------------------------------===// 274 // ConvertToLLVMPatternInterface implementation 275 //===----------------------------------------------------------------------===// 276 277 namespace { 278 /// Implement the interface to convert MemRef to LLVM. 
279 struct ControlFlowToLLVMDialectInterface 280 : public ConvertToLLVMPatternInterface { 281 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 282 void loadDependentDialects(MLIRContext *context) const final { 283 context->loadDialect<LLVM::LLVMDialect>(); 284 } 285 286 /// Hook for derived dialect interface to provide conversion patterns 287 /// and mark dialect legal for the conversion target. 288 void populateConvertToLLVMConversionPatterns( 289 ConversionTarget &target, LLVMTypeConverter &typeConverter, 290 RewritePatternSet &patterns) const final { 291 mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, 292 patterns); 293 mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); 294 } 295 }; 296 } // namespace 297 298 void mlir::cf::registerConvertControlFlowToLLVMInterface( 299 DialectRegistry ®istry) { 300 registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { 301 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); 302 }); 303 } 304