xref: /llvm-project/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (revision 4fed3d374dfca82d0cb32bb444985ece04438376)
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements conversion patterns and a pass that lower the MLIR
// ControlFlow (cf) dialect into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15 
16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
20 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
21 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "llvm/ADT/StringRef.h"
29 #include <functional>
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
33 #include "mlir/Conversion/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 #define PASS_NAME "convert-cf-to-llvm"
39 
40 namespace {
41 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
42 /// assertion is violated and has no effect otherwise. The failure message is
43 /// ignored by the default lowering but should be propagated by any custom
44 /// lowering.
45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46   explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
47                             bool abortOnFailedAssert = true)
48       : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49         abortOnFailedAssert(abortOnFailedAssert) {}
50 
51   LogicalResult
52   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53                   ConversionPatternRewriter &rewriter) const override {
54     auto loc = op.getLoc();
55     auto module = op->getParentOfType<ModuleOp>();
56 
57     // Split block at `assert` operation.
58     Block *opBlock = rewriter.getInsertionBlock();
59     auto opPosition = rewriter.getInsertionPoint();
60     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
61 
62     // Failed block: Generate IR to print the message and call `abort`.
63     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64     LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
65                              *getTypeConverter(), /*addNewLine=*/false,
66                              /*runtimeFunctionName=*/"puts");
67     if (abortOnFailedAssert) {
68       // Insert the `abort` declaration if necessary.
69       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
70       if (!abortFunc) {
71         OpBuilder::InsertionGuard guard(rewriter);
72         rewriter.setInsertionPointToStart(module.getBody());
73         auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
74         abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
75                                                       "abort", abortFuncTy);
76       }
77       rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
78       rewriter.create<LLVM::UnreachableOp>(loc);
79     } else {
80       rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
81     }
82 
83     // Generate assertion test.
84     rewriter.setInsertionPointToEnd(opBlock);
85     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
86         op, adaptor.getArg(), continuationBlock, failureBlock);
87 
88     return success();
89   }
90 
91 private:
92   /// If set to `false`, messages are printed but program execution continues.
93   /// This is useful for testing asserts.
94   bool abortOnFailedAssert = true;
95 };
96 
97 /// The cf->LLVM lowerings for branching ops require that the blocks they jump
98 /// to first have updated types which should be handled by a pattern operating
99 /// on the parent op.
100 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
101                                           ValueRange operands,
102                                           ValueRange blockArgs, Location loc,
103                                           llvm::StringRef messagePrefix) {
104   for (const auto &idxAndTypes :
105        llvm::enumerate(llvm::zip(blockArgs, operands))) {
106     int64_t i = idxAndTypes.index();
107     Value argValue =
108         rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
109     Type operandType = std::get<1>(idxAndTypes.value()).getType();
110     // In the case of an invalid jump, the block argument will have been
111     // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112     // there might still be a no-op conversion cast with both types being equal.
113     // Consider both of these details to see if the jump would be invalid.
114     if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115             argValue.getDefiningOp())) {
116       if (op.getOperandTypes().front() != operandType) {
117         return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
118           diag << messagePrefix;
119           diag << "mismatched types from operand # " << i << " ";
120           diag << operandType;
121           diag << " not compatible with destination block argument type ";
122           diag << op.getOperandTypes().front();
123           diag << " which should be converted with the parent op.";
124         });
125       }
126     }
127   }
128   return success();
129 }
130 
131 /// Ensure that all block types were updated and then create an LLVM::BrOp
132 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
133   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
134 
135   LogicalResult
136   matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137                   ConversionPatternRewriter &rewriter) const override {
138     if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
139                                     op.getSuccessor()->getArguments(),
140                                     op.getLoc(),
141                                     /*messagePrefix=*/"")))
142       return failure();
143 
144     rewriter.replaceOpWithNewOp<LLVM::BrOp>(
145         op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
146     return success();
147   }
148 };
149 
150 /// Ensure that all block types were updated and then create an LLVM::CondBrOp
151 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
152   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
153 
154   LogicalResult
155   matchAndRewrite(cf::CondBranchOp op,
156                   typename cf::CondBranchOp::Adaptor adaptor,
157                   ConversionPatternRewriter &rewriter) const override {
158     if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
159                                     op.getFalseDest()->getArguments(),
160                                     op.getLoc(), "in false case branch ")))
161       return failure();
162     if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
163                                     op.getTrueDest()->getArguments(),
164                                     op.getLoc(), "in true case branch ")))
165       return failure();
166 
167     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
168         op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
169     return success();
170   }
171 };
172 
173 /// Ensure that all block types were updated and then create an LLVM::SwitchOp
174 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
175   using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
176 
177   LogicalResult
178   matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179                   ConversionPatternRewriter &rewriter) const override {
180     if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
181                                     op.getDefaultDestination()->getArguments(),
182                                     op.getLoc(), "in switch default case ")))
183       return failure();
184 
185     for (const auto &i : llvm::enumerate(
186              llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
187       if (failed(verifyMatchingValues(
188               rewriter, std::get<0>(i.value()),
189               std::get<1>(i.value())->getArguments(), op.getLoc(),
190               "in switch case " + std::to_string(i.index()) + " "))) {
191         return failure();
192       }
193     }
194 
195     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
196         op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
197     return success();
198   }
199 };
200 
201 } // namespace
202 
203 void mlir::cf::populateControlFlowToLLVMConversionPatterns(
204     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
205   // clang-format off
206   patterns.add<
207       AssertOpLowering,
208       BranchOpLowering,
209       CondBranchOpLowering,
210       SwitchOpLowering>(converter);
211   // clang-format on
212 }
213 
214 void mlir::cf::populateAssertToLLVMConversionPattern(
215     LLVMTypeConverter &converter, RewritePatternSet &patterns,
216     bool abortOnFailure) {
217   patterns.add<AssertOpLowering>(converter, abortOnFailure);
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // Pass Definition
222 //===----------------------------------------------------------------------===//
223 
224 namespace {
225 /// A pass converting MLIR operations into the LLVM IR dialect.
226 struct ConvertControlFlowToLLVM
227     : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
228 
229   using Base::Base;
230 
231   /// Run the dialect converter on the module.
232   void runOnOperation() override {
233     LLVMConversionTarget target(getContext());
234     RewritePatternSet patterns(&getContext());
235 
236     LowerToLLVMOptions options(&getContext());
237     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
238       options.overrideIndexBitwidth(indexBitwidth);
239 
240     LLVMTypeConverter converter(&getContext(), options);
241     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
242 
243     if (failed(applyPartialConversion(getOperation(), target,
244                                       std::move(patterns))))
245       signalPassFailure();
246   }
247 };
248 } // namespace
249 
250 //===----------------------------------------------------------------------===//
251 // ConvertToLLVMPatternInterface implementation
252 //===----------------------------------------------------------------------===//
253 
254 namespace {
255 /// Implement the interface to convert MemRef to LLVM.
256 struct ControlFlowToLLVMDialectInterface
257     : public ConvertToLLVMPatternInterface {
258   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
259   void loadDependentDialects(MLIRContext *context) const final {
260     context->loadDialect<LLVM::LLVMDialect>();
261   }
262 
263   /// Hook for derived dialect interface to provide conversion patterns
264   /// and mark dialect legal for the conversion target.
265   void populateConvertToLLVMConversionPatterns(
266       ConversionTarget &target, LLVMTypeConverter &typeConverter,
267       RewritePatternSet &patterns) const final {
268     mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
269                                                           patterns);
270   }
271 };
272 } // namespace
273 
274 void mlir::cf::registerConvertControlFlowToLLVMInterface(
275     DialectRegistry &registry) {
276   registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
277     dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
278   });
279 }
280