xref: /llvm-project/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (revision 599c73990532333e62edf8ba19a5302b543f976f)
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the patterns and the pass that convert the MLIR
// ControlFlow dialect into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15 
16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
20 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
21 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "llvm/ADT/StringRef.h"
29 #include <functional>
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
33 #include "mlir/Conversion/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 #define PASS_NAME "convert-cf-to-llvm"
39 
40 namespace {
41 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
42 /// assertion is violated and has no effect otherwise. The failure message is
43 /// ignored by the default lowering but should be propagated by any custom
44 /// lowering.
45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46   explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
47                             bool abortOnFailedAssert = true)
48       : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49         abortOnFailedAssert(abortOnFailedAssert) {}
50 
51   LogicalResult
52   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53                   ConversionPatternRewriter &rewriter) const override {
54     auto loc = op.getLoc();
55     auto module = op->getParentOfType<ModuleOp>();
56 
57     // Split block at `assert` operation.
58     Block *opBlock = rewriter.getInsertionBlock();
59     auto opPosition = rewriter.getInsertionPoint();
60     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
61 
62     // Failed block: Generate IR to print the message and call `abort`.
63     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64     LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
65                              *getTypeConverter(), /*addNewLine=*/false,
66                              /*runtimeFunctionName=*/"puts");
67     if (abortOnFailedAssert) {
68       // Insert the `abort` declaration if necessary.
69       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
70       if (!abortFunc) {
71         OpBuilder::InsertionGuard guard(rewriter);
72         rewriter.setInsertionPointToStart(module.getBody());
73         auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
74         abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
75                                                       "abort", abortFuncTy);
76       }
77       rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
78       rewriter.create<LLVM::UnreachableOp>(loc);
79     } else {
80       rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
81     }
82 
83     // Generate assertion test.
84     rewriter.setInsertionPointToEnd(opBlock);
85     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
86         op, adaptor.getArg(), continuationBlock, failureBlock);
87 
88     return success();
89   }
90 
91 private:
92   /// If set to `false`, messages are printed but program execution continues.
93   /// This is useful for testing asserts.
94   bool abortOnFailedAssert = true;
95 };
96 
97 /// Helper function for converting branch ops. This function converts the
98 /// signature of the given block. If the new block signature is different from
99 /// `expectedTypes`, returns "failure".
100 static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
101                                             const TypeConverter *converter,
102                                             Operation *branchOp, Block *block,
103                                             TypeRange expectedTypes) {
104   assert(converter && "expected non-null type converter");
105   assert(!block->isEntryBlock() && "entry blocks have no predecessors");
106 
107   // There is nothing to do if the types already match.
108   if (block->getArgumentTypes() == expectedTypes)
109     return block;
110 
111   // Compute the new block argument types and convert the block.
112   std::optional<TypeConverter::SignatureConversion> conversion =
113       converter->convertBlockSignature(block);
114   if (!conversion)
115     return rewriter.notifyMatchFailure(branchOp,
116                                        "could not compute block signature");
117   if (expectedTypes != conversion->getConvertedTypes())
118     return rewriter.notifyMatchFailure(
119         branchOp,
120         "mismatch between adaptor operand types and computed block signature");
121   return rewriter.applySignatureConversion(block, *conversion, converter);
122 }
123 
124 /// Convert the destination block signature (if necessary) and lower the branch
125 /// op to llvm.br.
126 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
127   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
128 
129   LogicalResult
130   matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
131                   ConversionPatternRewriter &rewriter) const override {
132     FailureOr<Block *> convertedBlock =
133         getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
134                           TypeRange(adaptor.getOperands()));
135     if (failed(convertedBlock))
136       return failure();
137     Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
138         op, adaptor.getOperands(), *convertedBlock);
139     // TODO: We should not just forward all attributes like that. But there are
140     // existing Flang tests that depend on this behavior.
141     newOp->setAttrs(op->getAttrDictionary());
142     return success();
143   }
144 };
145 
146 /// Convert the destination block signatures (if necessary) and lower the
147 /// branch op to llvm.cond_br.
148 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
149   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
150 
151   LogicalResult
152   matchAndRewrite(cf::CondBranchOp op,
153                   typename cf::CondBranchOp::Adaptor adaptor,
154                   ConversionPatternRewriter &rewriter) const override {
155     FailureOr<Block *> convertedTrueBlock =
156         getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
157                           TypeRange(adaptor.getTrueDestOperands()));
158     if (failed(convertedTrueBlock))
159       return failure();
160     FailureOr<Block *> convertedFalseBlock =
161         getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
162                           TypeRange(adaptor.getFalseDestOperands()));
163     if (failed(convertedFalseBlock))
164       return failure();
165     Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
166         op, adaptor.getCondition(), *convertedTrueBlock,
167         adaptor.getTrueDestOperands(), *convertedFalseBlock,
168         adaptor.getFalseDestOperands());
169     // TODO: We should not just forward all attributes like that. But there are
170     // existing Flang tests that depend on this behavior.
171     newOp->setAttrs(op->getAttrDictionary());
172     return success();
173   }
174 };
175 
176 /// Convert the destination block signatures (if necessary) and lower the
177 /// switch op to llvm.switch.
178 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
179   using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
180 
181   LogicalResult
182   matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
183                   ConversionPatternRewriter &rewriter) const override {
184     // Get or convert default block.
185     FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
186         rewriter, getTypeConverter(), op, op.getDefaultDestination(),
187         TypeRange(adaptor.getDefaultOperands()));
188     if (failed(convertedDefaultBlock))
189       return failure();
190 
191     // Get or convert all case blocks.
192     SmallVector<Block *> caseDestinations;
193     SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
194     for (auto it : llvm::enumerate(op.getCaseDestinations())) {
195       Block *b = it.value();
196       FailureOr<Block *> convertedBlock =
197           getConvertedBlock(rewriter, getTypeConverter(), op, b,
198                             TypeRange(caseOperands[it.index()]));
199       if (failed(convertedBlock))
200         return failure();
201       caseDestinations.push_back(*convertedBlock);
202     }
203 
204     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
205         op, adaptor.getFlag(), *convertedDefaultBlock,
206         adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
207         caseDestinations, caseOperands);
208     return success();
209   }
210 };
211 
212 } // namespace
213 
214 void mlir::cf::populateControlFlowToLLVMConversionPatterns(
215     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
216   // clang-format off
217   patterns.add<
218       BranchOpLowering,
219       CondBranchOpLowering,
220       SwitchOpLowering>(converter);
221   // clang-format on
222 }
223 
224 void mlir::cf::populateAssertToLLVMConversionPattern(
225     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
226     bool abortOnFailure) {
227   patterns.add<AssertOpLowering>(converter, abortOnFailure);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // Pass Definition
232 //===----------------------------------------------------------------------===//
233 
234 namespace {
235 /// A pass converting MLIR operations into the LLVM IR dialect.
236 struct ConvertControlFlowToLLVM
237     : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
238 
239   using Base::Base;
240 
241   /// Run the dialect converter on the module.
242   void runOnOperation() override {
243     MLIRContext *ctx = &getContext();
244     LLVMConversionTarget target(*ctx);
245     // This pass lowers only CF dialect ops, but it also modifies block
246     // signatures inside other ops. These ops should be treated as legal. They
247     // are lowered by other passes.
248     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
249       return op->getDialect() !=
250              ctx->getLoadedDialect<cf::ControlFlowDialect>();
251     });
252 
253     LowerToLLVMOptions options(ctx);
254     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
255       options.overrideIndexBitwidth(indexBitwidth);
256 
257     LLVMTypeConverter converter(ctx, options);
258     RewritePatternSet patterns(ctx);
259     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
260     mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
261 
262     if (failed(applyPartialConversion(getOperation(), target,
263                                       std::move(patterns))))
264       signalPassFailure();
265   }
266 };
267 } // namespace
268 
269 //===----------------------------------------------------------------------===//
270 // ConvertToLLVMPatternInterface implementation
271 //===----------------------------------------------------------------------===//
272 
273 namespace {
274 /// Implement the interface to convert MemRef to LLVM.
275 struct ControlFlowToLLVMDialectInterface
276     : public ConvertToLLVMPatternInterface {
277   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
278   void loadDependentDialects(MLIRContext *context) const final {
279     context->loadDialect<LLVM::LLVMDialect>();
280   }
281 
282   /// Hook for derived dialect interface to provide conversion patterns
283   /// and mark dialect legal for the conversion target.
284   void populateConvertToLLVMConversionPatterns(
285       ConversionTarget &target, LLVMTypeConverter &typeConverter,
286       RewritePatternSet &patterns) const final {
287     mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
288                                                           patterns);
289     mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
290   }
291 };
292 } // namespace
293 
294 void mlir::cf::registerConvertControlFlowToLLVMInterface(
295     DialectRegistry &registry) {
296   registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
297     dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
298   });
299 }
300