xref: /llvm-project/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (revision 23aa5a744666b281af807b1f598f517bf0d597cb)
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass (and the associated rewrite patterns) to convert
// the MLIR ControlFlow (`cf`) dialect into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15 #include "../PassDetail.h"
16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include <functional>
26 
27 using namespace mlir;
28 
29 #define PASS_NAME "convert-cf-to-llvm"
30 
31 namespace {
32 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
33 /// assertion is violated and has no effect otherwise. The failure message is
34 /// ignored by the default lowering but should be propagated by any custom
35 /// lowering.
36 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
37   using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
38 
39   LogicalResult
40   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
41                   ConversionPatternRewriter &rewriter) const override {
42     auto loc = op.getLoc();
43 
44     // Insert the `abort` declaration if necessary.
45     auto module = op->getParentOfType<ModuleOp>();
46     auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
47     if (!abortFunc) {
48       OpBuilder::InsertionGuard guard(rewriter);
49       rewriter.setInsertionPointToStart(module.getBody());
50       auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
51       abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
52                                                     "abort", abortFuncTy);
53     }
54 
55     // Split block at `assert` operation.
56     Block *opBlock = rewriter.getInsertionBlock();
57     auto opPosition = rewriter.getInsertionPoint();
58     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
59 
60     // Generate IR to call `abort`.
61     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
62     rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
63     rewriter.create<LLVM::UnreachableOp>(loc);
64 
65     // Generate assertion test.
66     rewriter.setInsertionPointToEnd(opBlock);
67     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
68         op, adaptor.getArg(), continuationBlock, failureBlock);
69 
70     return success();
71   }
72 };
73 
74 // Base class for LLVM IR lowering terminator operations with successors.
75 template <typename SourceOp, typename TargetOp>
76 struct OneToOneLLVMTerminatorLowering
77     : public ConvertOpToLLVMPattern<SourceOp> {
78   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
79   using Base = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;
80 
81   LogicalResult
82   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
83                   ConversionPatternRewriter &rewriter) const override {
84     rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getOperands(),
85                                           op->getSuccessors(), op->getAttrs());
86     return success();
87   }
88 };
89 
90 // FIXME: this should be tablegen'ed as well.
91 struct BranchOpLowering
92     : public OneToOneLLVMTerminatorLowering<cf::BranchOp, LLVM::BrOp> {
93   using Base::Base;
94 };
95 struct CondBranchOpLowering
96     : public OneToOneLLVMTerminatorLowering<cf::CondBranchOp, LLVM::CondBrOp> {
97   using Base::Base;
98 };
99 struct SwitchOpLowering
100     : public OneToOneLLVMTerminatorLowering<cf::SwitchOp, LLVM::SwitchOp> {
101   using Base::Base;
102 };
103 
104 } // namespace
105 
106 void mlir::cf::populateControlFlowToLLVMConversionPatterns(
107     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
108   // clang-format off
109   patterns.add<
110       AssertOpLowering,
111       BranchOpLowering,
112       CondBranchOpLowering,
113       SwitchOpLowering>(converter);
114   // clang-format on
115 }
116 
117 //===----------------------------------------------------------------------===//
118 // Pass Definition
119 //===----------------------------------------------------------------------===//
120 
121 namespace {
122 /// A pass converting MLIR operations into the LLVM IR dialect.
123 struct ConvertControlFlowToLLVM
124     : public ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> {
125   ConvertControlFlowToLLVM() = default;
126 
127   /// Run the dialect converter on the module.
128   void runOnOperation() override {
129     LLVMConversionTarget target(getContext());
130     RewritePatternSet patterns(&getContext());
131 
132     LowerToLLVMOptions options(&getContext());
133     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
134       options.overrideIndexBitwidth(indexBitwidth);
135 
136     LLVMTypeConverter converter(&getContext(), options);
137     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
138 
139     if (failed(applyPartialConversion(getOperation(), target,
140                                       std::move(patterns))))
141       signalPassFailure();
142   }
143 };
144 } // namespace
145 
146 std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() {
147   return std::make_unique<ConvertControlFlowToLLVM>();
148 }
149