//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert MLIR standard and builtin dialects // into the LLVM IR dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/StringRef.h" #include namespace mlir { #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; #define PASS_NAME "convert-cf-to-llvm" namespace { /// Lower `cf.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is /// ignored by the default lowering but should be propagated by any custom /// lowering. struct AssertOpLowering : public ConvertOpToLLVMPattern { explicit AssertOpLowering(const LLVMTypeConverter &typeConverter, bool abortOnFailedAssert = true) : ConvertOpToLLVMPattern(typeConverter, /*benefit=*/1), abortOnFailedAssert(abortOnFailedAssert) {} LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto module = op->getParentOfType(); // Split block at `assert` operation. Block *opBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); // Failed block: Generate IR to print the message and call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); auto createResult = LLVM::createPrintStrCall( rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(), /*addNewLine=*/false, /*runtimeFunctionName=*/"puts"); if (createResult.failed()) return failure(); if (abortOnFailedAssert) { // Insert the `abort` declaration if necessary. auto abortFunc = module.lookupSymbol("abort"); if (!abortFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); abortFunc = rewriter.create(rewriter.getUnknownLoc(), "abort", abortFuncTy); } rewriter.create(loc, abortFunc, std::nullopt); rewriter.create(loc); } else { rewriter.create(loc, ValueRange(), continuationBlock); } // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); rewriter.replaceOpWithNewOp( op, adaptor.getArg(), continuationBlock, failureBlock); return success(); } private: /// If set to `false`, messages are printed but program execution continues. /// This is useful for testing asserts. bool abortOnFailedAssert = true; }; /// Helper function for converting branch ops. This function converts the /// signature of the given block. If the new block signature is different from /// `expectedTypes`, returns "failure". static FailureOr getConvertedBlock(ConversionPatternRewriter &rewriter, const TypeConverter *converter, Operation *branchOp, Block *block, TypeRange expectedTypes) { assert(converter && "expected non-null type converter"); assert(!block->isEntryBlock() && "entry blocks have no predecessors"); // There is nothing to do if the types already match. if (block->getArgumentTypes() == expectedTypes) return block; // Compute the new block argument types and convert the block. std::optional conversion = converter->convertBlockSignature(block); if (!conversion) return rewriter.notifyMatchFailure(branchOp, "could not compute block signature"); if (expectedTypes != conversion->getConvertedTypes()) return rewriter.notifyMatchFailure( branchOp, "mismatch between adaptor operand types and computed block signature"); return rewriter.applySignatureConversion(block, *conversion, converter); } /// Convert the destination block signature (if necessary) and lower the branch /// op to llvm.br. struct BranchOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { FailureOr convertedBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), TypeRange(adaptor.getOperands())); if (failed(convertedBlock)) return failure(); Operation *newOp = rewriter.replaceOpWithNewOp( op, adaptor.getOperands(), *convertedBlock); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. newOp->setAttrs(op->getAttrDictionary()); return success(); } }; /// Convert the destination block signatures (if necessary) and lower the /// branch op to llvm.cond_br. struct CondBranchOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(cf::CondBranchOp op, typename cf::CondBranchOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { FailureOr convertedTrueBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), TypeRange(adaptor.getTrueDestOperands())); if (failed(convertedTrueBlock)) return failure(); FailureOr convertedFalseBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), TypeRange(adaptor.getFalseDestOperands())); if (failed(convertedFalseBlock)) return failure(); Operation *newOp = rewriter.replaceOpWithNewOp( op, adaptor.getCondition(), *convertedTrueBlock, adaptor.getTrueDestOperands(), *convertedFalseBlock, adaptor.getFalseDestOperands()); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. newOp->setAttrs(op->getAttrDictionary()); return success(); } }; /// Convert the destination block signatures (if necessary) and lower the /// switch op to llvm.switch. struct SwitchOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Get or convert default block. FailureOr convertedDefaultBlock = getConvertedBlock( rewriter, getTypeConverter(), op, op.getDefaultDestination(), TypeRange(adaptor.getDefaultOperands())); if (failed(convertedDefaultBlock)) return failure(); // Get or convert all case blocks. SmallVector caseDestinations; SmallVector caseOperands = adaptor.getCaseOperands(); for (auto it : llvm::enumerate(op.getCaseDestinations())) { Block *b = it.value(); FailureOr convertedBlock = getConvertedBlock(rewriter, getTypeConverter(), op, b, TypeRange(caseOperands[it.index()])); if (failed(convertedBlock)) return failure(); caseDestinations.push_back(*convertedBlock); } rewriter.replaceOpWithNewOp( op, adaptor.getFlag(), *convertedDefaultBlock, adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(), caseDestinations, caseOperands); return success(); } }; } // namespace void mlir::cf::populateControlFlowToLLVMConversionPatterns( const LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< BranchOpLowering, CondBranchOpLowering, SwitchOpLowering>(converter); // clang-format on } void mlir::cf::populateAssertToLLVMConversionPattern( const LLVMTypeConverter &converter, RewritePatternSet &patterns, bool abortOnFailure) { patterns.add(converter, abortOnFailure); } //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct ConvertControlFlowToLLVM : public impl::ConvertControlFlowToLLVMPassBase { using Base::Base; /// Run the dialect converter on the module. void runOnOperation() override { MLIRContext *ctx = &getContext(); LLVMConversionTarget target(*ctx); // This pass lowers only CF dialect ops, but it also modifies block // signatures inside other ops. These ops should be treated as legal. They // are lowered by other passes. target.markUnknownOpDynamicallyLegal([&](Operation *op) { return op->getDialect() != ctx->getLoadedDialect(); }); LowerToLLVMOptions options(ctx); if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); LLVMTypeConverter converter(ctx, options); RewritePatternSet patterns(ctx); mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace //===----------------------------------------------------------------------===// // ConvertToLLVMPatternInterface implementation //===----------------------------------------------------------------------===// namespace { /// Implement the interface to convert MemRef to LLVM. struct ControlFlowToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; void loadDependentDialects(MLIRContext *context) const final { context->loadDialect(); } /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); } }; } // namespace void mlir::cf::registerConvertControlFlowToLLVMInterface( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { dialect->addInterfaces(); }); }