xref: /llvm-project/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert MLIR standard and builtin dialects
10 // into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15 
16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
20 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
21 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24 #include "mlir/IR/BuiltinOps.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "llvm/ADT/StringRef.h"
29 #include <functional>
30 
31 namespace mlir {
32 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
33 #include "mlir/Conversion/Passes.h.inc"
34 } // namespace mlir
35 
36 using namespace mlir;
37 
38 #define PASS_NAME "convert-cf-to-llvm"
39 
40 namespace {
41 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
42 /// assertion is violated and has no effect otherwise. The failure message is
43 /// ignored by the default lowering but should be propagated by any custom
44 /// lowering.
45 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46   explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
47                             bool abortOnFailedAssert = true)
48       : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49         abortOnFailedAssert(abortOnFailedAssert) {}
50 
51   LogicalResult
52   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53                   ConversionPatternRewriter &rewriter) const override {
54     auto loc = op.getLoc();
55     auto module = op->getParentOfType<ModuleOp>();
56 
57     // Split block at `assert` operation.
58     Block *opBlock = rewriter.getInsertionBlock();
59     auto opPosition = rewriter.getInsertionPoint();
60     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
61 
62     // Failed block: Generate IR to print the message and call `abort`.
63     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64     auto createResult = LLVM::createPrintStrCall(
65         rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
66         /*addNewLine=*/false,
67         /*runtimeFunctionName=*/"puts");
68     if (createResult.failed())
69       return failure();
70 
71     if (abortOnFailedAssert) {
72       // Insert the `abort` declaration if necessary.
73       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
74       if (!abortFunc) {
75         OpBuilder::InsertionGuard guard(rewriter);
76         rewriter.setInsertionPointToStart(module.getBody());
77         auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
78         abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
79                                                       "abort", abortFuncTy);
80       }
81       rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
82       rewriter.create<LLVM::UnreachableOp>(loc);
83     } else {
84       rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
85     }
86 
87     // Generate assertion test.
88     rewriter.setInsertionPointToEnd(opBlock);
89     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
90         op, adaptor.getArg(), continuationBlock, failureBlock);
91 
92     return success();
93   }
94 
95 private:
96   /// If set to `false`, messages are printed but program execution continues.
97   /// This is useful for testing asserts.
98   bool abortOnFailedAssert = true;
99 };
100 
101 /// Helper function for converting branch ops. This function converts the
102 /// signature of the given block. If the new block signature is different from
103 /// `expectedTypes`, returns "failure".
104 static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
105                                             const TypeConverter *converter,
106                                             Operation *branchOp, Block *block,
107                                             TypeRange expectedTypes) {
108   assert(converter && "expected non-null type converter");
109   assert(!block->isEntryBlock() && "entry blocks have no predecessors");
110 
111   // There is nothing to do if the types already match.
112   if (block->getArgumentTypes() == expectedTypes)
113     return block;
114 
115   // Compute the new block argument types and convert the block.
116   std::optional<TypeConverter::SignatureConversion> conversion =
117       converter->convertBlockSignature(block);
118   if (!conversion)
119     return rewriter.notifyMatchFailure(branchOp,
120                                        "could not compute block signature");
121   if (expectedTypes != conversion->getConvertedTypes())
122     return rewriter.notifyMatchFailure(
123         branchOp,
124         "mismatch between adaptor operand types and computed block signature");
125   return rewriter.applySignatureConversion(block, *conversion, converter);
126 }
127 
128 /// Convert the destination block signature (if necessary) and lower the branch
129 /// op to llvm.br.
130 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
131   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
132 
133   LogicalResult
134   matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
135                   ConversionPatternRewriter &rewriter) const override {
136     FailureOr<Block *> convertedBlock =
137         getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
138                           TypeRange(adaptor.getOperands()));
139     if (failed(convertedBlock))
140       return failure();
141     Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
142         op, adaptor.getOperands(), *convertedBlock);
143     // TODO: We should not just forward all attributes like that. But there are
144     // existing Flang tests that depend on this behavior.
145     newOp->setAttrs(op->getAttrDictionary());
146     return success();
147   }
148 };
149 
150 /// Convert the destination block signatures (if necessary) and lower the
151 /// branch op to llvm.cond_br.
152 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
153   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
154 
155   LogicalResult
156   matchAndRewrite(cf::CondBranchOp op,
157                   typename cf::CondBranchOp::Adaptor adaptor,
158                   ConversionPatternRewriter &rewriter) const override {
159     FailureOr<Block *> convertedTrueBlock =
160         getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
161                           TypeRange(adaptor.getTrueDestOperands()));
162     if (failed(convertedTrueBlock))
163       return failure();
164     FailureOr<Block *> convertedFalseBlock =
165         getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
166                           TypeRange(adaptor.getFalseDestOperands()));
167     if (failed(convertedFalseBlock))
168       return failure();
169     Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
170         op, adaptor.getCondition(), *convertedTrueBlock,
171         adaptor.getTrueDestOperands(), *convertedFalseBlock,
172         adaptor.getFalseDestOperands());
173     // TODO: We should not just forward all attributes like that. But there are
174     // existing Flang tests that depend on this behavior.
175     newOp->setAttrs(op->getAttrDictionary());
176     return success();
177   }
178 };
179 
180 /// Convert the destination block signatures (if necessary) and lower the
181 /// switch op to llvm.switch.
182 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
183   using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
184 
185   LogicalResult
186   matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
187                   ConversionPatternRewriter &rewriter) const override {
188     // Get or convert default block.
189     FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
190         rewriter, getTypeConverter(), op, op.getDefaultDestination(),
191         TypeRange(adaptor.getDefaultOperands()));
192     if (failed(convertedDefaultBlock))
193       return failure();
194 
195     // Get or convert all case blocks.
196     SmallVector<Block *> caseDestinations;
197     SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
198     for (auto it : llvm::enumerate(op.getCaseDestinations())) {
199       Block *b = it.value();
200       FailureOr<Block *> convertedBlock =
201           getConvertedBlock(rewriter, getTypeConverter(), op, b,
202                             TypeRange(caseOperands[it.index()]));
203       if (failed(convertedBlock))
204         return failure();
205       caseDestinations.push_back(*convertedBlock);
206     }
207 
208     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
209         op, adaptor.getFlag(), *convertedDefaultBlock,
210         adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
211         caseDestinations, caseOperands);
212     return success();
213   }
214 };
215 
216 } // namespace
217 
218 void mlir::cf::populateControlFlowToLLVMConversionPatterns(
219     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
220   // clang-format off
221   patterns.add<
222       BranchOpLowering,
223       CondBranchOpLowering,
224       SwitchOpLowering>(converter);
225   // clang-format on
226 }
227 
228 void mlir::cf::populateAssertToLLVMConversionPattern(
229     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
230     bool abortOnFailure) {
231   patterns.add<AssertOpLowering>(converter, abortOnFailure);
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // Pass Definition
236 //===----------------------------------------------------------------------===//
237 
238 namespace {
239 /// A pass converting MLIR operations into the LLVM IR dialect.
240 struct ConvertControlFlowToLLVM
241     : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
242 
243   using Base::Base;
244 
245   /// Run the dialect converter on the module.
246   void runOnOperation() override {
247     MLIRContext *ctx = &getContext();
248     LLVMConversionTarget target(*ctx);
249     // This pass lowers only CF dialect ops, but it also modifies block
250     // signatures inside other ops. These ops should be treated as legal. They
251     // are lowered by other passes.
252     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
253       return op->getDialect() !=
254              ctx->getLoadedDialect<cf::ControlFlowDialect>();
255     });
256 
257     LowerToLLVMOptions options(ctx);
258     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
259       options.overrideIndexBitwidth(indexBitwidth);
260 
261     LLVMTypeConverter converter(ctx, options);
262     RewritePatternSet patterns(ctx);
263     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
264     mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
265 
266     if (failed(applyPartialConversion(getOperation(), target,
267                                       std::move(patterns))))
268       signalPassFailure();
269   }
270 };
271 } // namespace
272 
273 //===----------------------------------------------------------------------===//
274 // ConvertToLLVMPatternInterface implementation
275 //===----------------------------------------------------------------------===//
276 
277 namespace {
278 /// Implement the interface to convert MemRef to LLVM.
279 struct ControlFlowToLLVMDialectInterface
280     : public ConvertToLLVMPatternInterface {
281   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
282   void loadDependentDialects(MLIRContext *context) const final {
283     context->loadDialect<LLVM::LLVMDialect>();
284   }
285 
286   /// Hook for derived dialect interface to provide conversion patterns
287   /// and mark dialect legal for the conversion target.
288   void populateConvertToLLVMConversionPatterns(
289       ConversionTarget &target, LLVMTypeConverter &typeConverter,
290       RewritePatternSet &patterns) const final {
291     mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
292                                                           patterns);
293     mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
294   }
295 };
296 } // namespace
297 
298 void mlir::cf::registerConvertControlFlowToLLVMInterface(
299     DialectRegistry &registry) {
300   registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
301     dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
302   });
303 }
304