xref: /llvm-project/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (revision e84f6b6a88c1222d512edf0644c8f869dd12b8ef)
1ace01605SRiver Riddle //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2ace01605SRiver Riddle //
3ace01605SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ace01605SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5ace01605SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ace01605SRiver Riddle //
7ace01605SRiver Riddle //===----------------------------------------------------------------------===//
8ace01605SRiver Riddle //
// This file implements patterns and a pass to convert the ControlFlow (`cf`)
// dialect into the LLVM IR dialect.
11ace01605SRiver Riddle //
12ace01605SRiver Riddle //===----------------------------------------------------------------------===//
13ace01605SRiver Riddle 
14ace01605SRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
1567d0d7acSMichele Scuttari 
16c4769ef5SMatthias Springer #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/Pattern.h"
193be3883eSBenjamin Maxwell #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
20ace01605SRiver Riddle #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
21ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
23ace01605SRiver Riddle #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24ace01605SRiver Riddle #include "mlir/IR/BuiltinOps.h"
25ace01605SRiver Riddle #include "mlir/IR/PatternMatch.h"
2667d0d7acSMichele Scuttari #include "mlir/Pass/Pass.h"
27ace01605SRiver Riddle #include "mlir/Transforms/DialectConversion.h"
28448adfeeSTres Popp #include "llvm/ADT/StringRef.h"
29ace01605SRiver Riddle #include <functional>
30ace01605SRiver Riddle 
3167d0d7acSMichele Scuttari namespace mlir {
32cd4ca2d7SMarkus Böck #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
3367d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc"
3467d0d7acSMichele Scuttari } // namespace mlir
3567d0d7acSMichele Scuttari 
36ace01605SRiver Riddle using namespace mlir;
37ace01605SRiver Riddle 
38ace01605SRiver Riddle #define PASS_NAME "convert-cf-to-llvm"
39ace01605SRiver Riddle 
40ace01605SRiver Riddle namespace {
4123aa5a74SRiver Riddle /// Lower `cf.assert`. The default lowering calls the `abort` function if the
42ace01605SRiver Riddle /// assertion is violated and has no effect otherwise. The failure message is
43ace01605SRiver Riddle /// ignored by the default lowering but should be propagated by any custom
44ace01605SRiver Riddle /// lowering.
45ace01605SRiver Riddle struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46206fad0eSMatthias Springer   explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
47325b58d5SMatthias Springer                             bool abortOnFailedAssert = true)
48325b58d5SMatthias Springer       : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49325b58d5SMatthias Springer         abortOnFailedAssert(abortOnFailedAssert) {}
50ace01605SRiver Riddle 
51ace01605SRiver Riddle   LogicalResult
52ace01605SRiver Riddle   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
54ace01605SRiver Riddle     auto loc = op.getLoc();
55ace01605SRiver Riddle     auto module = op->getParentOfType<ModuleOp>();
56325b58d5SMatthias Springer 
57325b58d5SMatthias Springer     // Split block at `assert` operation.
58325b58d5SMatthias Springer     Block *opBlock = rewriter.getInsertionBlock();
59325b58d5SMatthias Springer     auto opPosition = rewriter.getInsertionPoint();
60325b58d5SMatthias Springer     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
61325b58d5SMatthias Springer 
62325b58d5SMatthias Springer     // Failed block: Generate IR to print the message and call `abort`.
63325b58d5SMatthias Springer     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64*e84f6b6aSLuohao Wang     auto createResult = LLVM::createPrintStrCall(
65*e84f6b6aSLuohao Wang         rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
66*e84f6b6aSLuohao Wang         /*addNewLine=*/false,
673be3883eSBenjamin Maxwell         /*runtimeFunctionName=*/"puts");
68*e84f6b6aSLuohao Wang     if (createResult.failed())
69*e84f6b6aSLuohao Wang       return failure();
70*e84f6b6aSLuohao Wang 
71325b58d5SMatthias Springer     if (abortOnFailedAssert) {
72325b58d5SMatthias Springer       // Insert the `abort` declaration if necessary.
73ace01605SRiver Riddle       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
74ace01605SRiver Riddle       if (!abortFunc) {
75ace01605SRiver Riddle         OpBuilder::InsertionGuard guard(rewriter);
76ace01605SRiver Riddle         rewriter.setInsertionPointToStart(module.getBody());
77ace01605SRiver Riddle         auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
78ace01605SRiver Riddle         abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
79ace01605SRiver Riddle                                                       "abort", abortFuncTy);
80ace01605SRiver Riddle       }
811a36588eSKazu Hirata       rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
82ace01605SRiver Riddle       rewriter.create<LLVM::UnreachableOp>(loc);
83325b58d5SMatthias Springer     } else {
84325b58d5SMatthias Springer       rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
85325b58d5SMatthias Springer     }
86ace01605SRiver Riddle 
87ace01605SRiver Riddle     // Generate assertion test.
88ace01605SRiver Riddle     rewriter.setInsertionPointToEnd(opBlock);
89ace01605SRiver Riddle     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
90ace01605SRiver Riddle         op, adaptor.getArg(), continuationBlock, failureBlock);
91ace01605SRiver Riddle 
92ace01605SRiver Riddle     return success();
93ace01605SRiver Riddle   }
94325b58d5SMatthias Springer 
95325b58d5SMatthias Springer private:
96325b58d5SMatthias Springer   /// If set to `false`, messages are printed but program execution continues.
97325b58d5SMatthias Springer   /// This is useful for testing asserts.
98325b58d5SMatthias Springer   bool abortOnFailedAssert = true;
99ace01605SRiver Riddle };
100ace01605SRiver Riddle 
101eb6c4197SMatthias Springer /// Helper function for converting branch ops. This function converts the
102eb6c4197SMatthias Springer /// signature of the given block. If the new block signature is different from
103eb6c4197SMatthias Springer /// `expectedTypes`, returns "failure".
104eb6c4197SMatthias Springer static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
105eb6c4197SMatthias Springer                                             const TypeConverter *converter,
106eb6c4197SMatthias Springer                                             Operation *branchOp, Block *block,
107eb6c4197SMatthias Springer                                             TypeRange expectedTypes) {
108eb6c4197SMatthias Springer   assert(converter && "expected non-null type converter");
109eb6c4197SMatthias Springer   assert(!block->isEntryBlock() && "entry blocks have no predecessors");
110eb6c4197SMatthias Springer 
111eb6c4197SMatthias Springer   // There is nothing to do if the types already match.
112eb6c4197SMatthias Springer   if (block->getArgumentTypes() == expectedTypes)
113eb6c4197SMatthias Springer     return block;
114eb6c4197SMatthias Springer 
115eb6c4197SMatthias Springer   // Compute the new block argument types and convert the block.
116eb6c4197SMatthias Springer   std::optional<TypeConverter::SignatureConversion> conversion =
117eb6c4197SMatthias Springer       converter->convertBlockSignature(block);
118eb6c4197SMatthias Springer   if (!conversion)
119eb6c4197SMatthias Springer     return rewriter.notifyMatchFailure(branchOp,
120eb6c4197SMatthias Springer                                        "could not compute block signature");
121eb6c4197SMatthias Springer   if (expectedTypes != conversion->getConvertedTypes())
122eb6c4197SMatthias Springer     return rewriter.notifyMatchFailure(
123eb6c4197SMatthias Springer         branchOp,
124eb6c4197SMatthias Springer         "mismatch between adaptor operand types and computed block signature");
125eb6c4197SMatthias Springer   return rewriter.applySignatureConversion(block, *conversion, converter);
126448adfeeSTres Popp }
127448adfeeSTres Popp 
128eb6c4197SMatthias Springer /// Convert the destination block signature (if necessary) and lower the branch
129eb6c4197SMatthias Springer /// op to llvm.br.
130448adfeeSTres Popp struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
131448adfeeSTres Popp   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
132ace01605SRiver Riddle 
133ace01605SRiver Riddle   LogicalResult
134448adfeeSTres Popp   matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
135ace01605SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
136eb6c4197SMatthias Springer     FailureOr<Block *> convertedBlock =
137eb6c4197SMatthias Springer         getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
138eb6c4197SMatthias Springer                           TypeRange(adaptor.getOperands()));
139eb6c4197SMatthias Springer     if (failed(convertedBlock))
140448adfeeSTres Popp       return failure();
141eb6c4197SMatthias Springer     Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
142eb6c4197SMatthias Springer         op, adaptor.getOperands(), *convertedBlock);
143eb6c4197SMatthias Springer     // TODO: We should not just forward all attributes like that. But there are
144eb6c4197SMatthias Springer     // existing Flang tests that depend on this behavior.
145eb6c4197SMatthias Springer     newOp->setAttrs(op->getAttrDictionary());
146ace01605SRiver Riddle     return success();
147ace01605SRiver Riddle   }
148ace01605SRiver Riddle };
149ace01605SRiver Riddle 
150eb6c4197SMatthias Springer /// Convert the destination block signatures (if necessary) and lower the
151eb6c4197SMatthias Springer /// branch op to llvm.cond_br.
152448adfeeSTres Popp struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
153448adfeeSTres Popp   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
154448adfeeSTres Popp 
155448adfeeSTres Popp   LogicalResult
156448adfeeSTres Popp   matchAndRewrite(cf::CondBranchOp op,
157448adfeeSTres Popp                   typename cf::CondBranchOp::Adaptor adaptor,
158448adfeeSTres Popp                   ConversionPatternRewriter &rewriter) const override {
159eb6c4197SMatthias Springer     FailureOr<Block *> convertedTrueBlock =
160eb6c4197SMatthias Springer         getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
161eb6c4197SMatthias Springer                           TypeRange(adaptor.getTrueDestOperands()));
162eb6c4197SMatthias Springer     if (failed(convertedTrueBlock))
163448adfeeSTres Popp       return failure();
164eb6c4197SMatthias Springer     FailureOr<Block *> convertedFalseBlock =
165eb6c4197SMatthias Springer         getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
166eb6c4197SMatthias Springer                           TypeRange(adaptor.getFalseDestOperands()));
167eb6c4197SMatthias Springer     if (failed(convertedFalseBlock))
168448adfeeSTres Popp       return failure();
169eb6c4197SMatthias Springer     Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
170eb6c4197SMatthias Springer         op, adaptor.getCondition(), *convertedTrueBlock,
171eb6c4197SMatthias Springer         adaptor.getTrueDestOperands(), *convertedFalseBlock,
172eb6c4197SMatthias Springer         adaptor.getFalseDestOperands());
173eb6c4197SMatthias Springer     // TODO: We should not just forward all attributes like that. But there are
174eb6c4197SMatthias Springer     // existing Flang tests that depend on this behavior.
175eb6c4197SMatthias Springer     newOp->setAttrs(op->getAttrDictionary());
176448adfeeSTres Popp     return success();
177448adfeeSTres Popp   }
178ace01605SRiver Riddle };
179448adfeeSTres Popp 
180eb6c4197SMatthias Springer /// Convert the destination block signatures (if necessary) and lower the
181eb6c4197SMatthias Springer /// switch op to llvm.switch.
182448adfeeSTres Popp struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
183448adfeeSTres Popp   using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
184448adfeeSTres Popp 
185448adfeeSTres Popp   LogicalResult
186448adfeeSTres Popp   matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
187448adfeeSTres Popp                   ConversionPatternRewriter &rewriter) const override {
188eb6c4197SMatthias Springer     // Get or convert default block.
189eb6c4197SMatthias Springer     FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
190eb6c4197SMatthias Springer         rewriter, getTypeConverter(), op, op.getDefaultDestination(),
191eb6c4197SMatthias Springer         TypeRange(adaptor.getDefaultOperands()));
192eb6c4197SMatthias Springer     if (failed(convertedDefaultBlock))
193448adfeeSTres Popp       return failure();
194448adfeeSTres Popp 
195eb6c4197SMatthias Springer     // Get or convert all case blocks.
196eb6c4197SMatthias Springer     SmallVector<Block *> caseDestinations;
197eb6c4197SMatthias Springer     SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
198eb6c4197SMatthias Springer     for (auto it : llvm::enumerate(op.getCaseDestinations())) {
199eb6c4197SMatthias Springer       Block *b = it.value();
200eb6c4197SMatthias Springer       FailureOr<Block *> convertedBlock =
201eb6c4197SMatthias Springer           getConvertedBlock(rewriter, getTypeConverter(), op, b,
202eb6c4197SMatthias Springer                             TypeRange(caseOperands[it.index()]));
203eb6c4197SMatthias Springer       if (failed(convertedBlock))
204448adfeeSTres Popp         return failure();
205eb6c4197SMatthias Springer       caseDestinations.push_back(*convertedBlock);
206448adfeeSTres Popp     }
207448adfeeSTres Popp 
208448adfeeSTres Popp     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
209eb6c4197SMatthias Springer         op, adaptor.getFlag(), *convertedDefaultBlock,
210eb6c4197SMatthias Springer         adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
211eb6c4197SMatthias Springer         caseDestinations, caseOperands);
212448adfeeSTres Popp     return success();
213448adfeeSTres Popp   }
214ace01605SRiver Riddle };
215ace01605SRiver Riddle 
216ace01605SRiver Riddle } // namespace
217ace01605SRiver Riddle 
218fcb0294bSAlex Zinenko void mlir::cf::populateControlFlowToLLVMConversionPatterns(
219206fad0eSMatthias Springer     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
220ace01605SRiver Riddle   // clang-format off
221ace01605SRiver Riddle   patterns.add<
222ace01605SRiver Riddle       BranchOpLowering,
223ace01605SRiver Riddle       CondBranchOpLowering,
224ace01605SRiver Riddle       SwitchOpLowering>(converter);
225ace01605SRiver Riddle   // clang-format on
226ace01605SRiver Riddle }
227ace01605SRiver Riddle 
228325b58d5SMatthias Springer void mlir::cf::populateAssertToLLVMConversionPattern(
229206fad0eSMatthias Springer     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
230325b58d5SMatthias Springer     bool abortOnFailure) {
231325b58d5SMatthias Springer   patterns.add<AssertOpLowering>(converter, abortOnFailure);
232325b58d5SMatthias Springer }
233325b58d5SMatthias Springer 
234ace01605SRiver Riddle //===----------------------------------------------------------------------===//
235ace01605SRiver Riddle // Pass Definition
236ace01605SRiver Riddle //===----------------------------------------------------------------------===//
237ace01605SRiver Riddle 
238ace01605SRiver Riddle namespace {
239ace01605SRiver Riddle /// A pass converting MLIR operations into the LLVM IR dialect.
240039b969bSMichele Scuttari struct ConvertControlFlowToLLVM
241cd4ca2d7SMarkus Böck     : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
242cd4ca2d7SMarkus Böck 
243cd4ca2d7SMarkus Böck   using Base::Base;
244ace01605SRiver Riddle 
245ace01605SRiver Riddle   /// Run the dialect converter on the module.
246ace01605SRiver Riddle   void runOnOperation() override {
247eb6c4197SMatthias Springer     MLIRContext *ctx = &getContext();
248eb6c4197SMatthias Springer     LLVMConversionTarget target(*ctx);
249eb6c4197SMatthias Springer     // This pass lowers only CF dialect ops, but it also modifies block
250eb6c4197SMatthias Springer     // signatures inside other ops. These ops should be treated as legal. They
251eb6c4197SMatthias Springer     // are lowered by other passes.
252eb6c4197SMatthias Springer     target.markUnknownOpDynamicallyLegal([&](Operation *op) {
253eb6c4197SMatthias Springer       return op->getDialect() !=
254eb6c4197SMatthias Springer              ctx->getLoadedDialect<cf::ControlFlowDialect>();
255eb6c4197SMatthias Springer     });
256ace01605SRiver Riddle 
257eb6c4197SMatthias Springer     LowerToLLVMOptions options(ctx);
258ace01605SRiver Riddle     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
259ace01605SRiver Riddle       options.overrideIndexBitwidth(indexBitwidth);
260ace01605SRiver Riddle 
261eb6c4197SMatthias Springer     LLVMTypeConverter converter(ctx, options);
262eb6c4197SMatthias Springer     RewritePatternSet patterns(ctx);
263fcb0294bSAlex Zinenko     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
264599c7399SMatthias Springer     mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
265fcb0294bSAlex Zinenko 
266ace01605SRiver Riddle     if (failed(applyPartialConversion(getOperation(), target,
267ace01605SRiver Riddle                                       std::move(patterns))))
268ace01605SRiver Riddle       signalPassFailure();
269ace01605SRiver Riddle   }
270ace01605SRiver Riddle };
271ace01605SRiver Riddle } // namespace
272c4769ef5SMatthias Springer 
273c4769ef5SMatthias Springer //===----------------------------------------------------------------------===//
274c4769ef5SMatthias Springer // ConvertToLLVMPatternInterface implementation
275c4769ef5SMatthias Springer //===----------------------------------------------------------------------===//
276c4769ef5SMatthias Springer 
277c4769ef5SMatthias Springer namespace {
278c4769ef5SMatthias Springer /// Implement the interface to convert MemRef to LLVM.
279c4769ef5SMatthias Springer struct ControlFlowToLLVMDialectInterface
280c4769ef5SMatthias Springer     : public ConvertToLLVMPatternInterface {
281c4769ef5SMatthias Springer   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
282c4769ef5SMatthias Springer   void loadDependentDialects(MLIRContext *context) const final {
283c4769ef5SMatthias Springer     context->loadDialect<LLVM::LLVMDialect>();
284c4769ef5SMatthias Springer   }
285c4769ef5SMatthias Springer 
286c4769ef5SMatthias Springer   /// Hook for derived dialect interface to provide conversion patterns
287c4769ef5SMatthias Springer   /// and mark dialect legal for the conversion target.
288c4769ef5SMatthias Springer   void populateConvertToLLVMConversionPatterns(
289c4769ef5SMatthias Springer       ConversionTarget &target, LLVMTypeConverter &typeConverter,
290c4769ef5SMatthias Springer       RewritePatternSet &patterns) const final {
291c4769ef5SMatthias Springer     mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
292c4769ef5SMatthias Springer                                                           patterns);
293599c7399SMatthias Springer     mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns);
294c4769ef5SMatthias Springer   }
295c4769ef5SMatthias Springer };
296c4769ef5SMatthias Springer } // namespace
297c4769ef5SMatthias Springer 
298c4769ef5SMatthias Springer void mlir::cf::registerConvertControlFlowToLLVMInterface(
299c4769ef5SMatthias Springer     DialectRegistry &registry) {
300c4769ef5SMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
301c4769ef5SMatthias Springer     dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
302c4769ef5SMatthias Springer   });
303c4769ef5SMatthias Springer }
304