xref: /llvm-project/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (revision ce254598b73b119c9463f5b7f4131559e276e844)
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert the MLIR ControlFlow (`cf`) dialect
// into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15 
16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
20 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
21 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/ADT/StringRef.h"
28 #include <functional>
29 
30 namespace mlir {
31 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
32 #include "mlir/Conversion/Passes.h.inc"
33 } // namespace mlir
34 
35 using namespace mlir;
36 
37 #define PASS_NAME "convert-cf-to-llvm"
38 
39 static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
40   std::string prefix = "assert_msg_";
41   int counter = 0;
42   while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
43     ++counter;
44   return prefix + std::to_string(counter);
45 }
46 
47 /// Generate IR that prints the given string to stderr.
48 static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
49                            StringRef msg,
50                            const LLVMTypeConverter &typeConverter) {
51   auto ip = builder.saveInsertionPoint();
52   builder.setInsertionPointToStart(moduleOp.getBody());
53   MLIRContext *ctx = builder.getContext();
54 
55   // Create a zero-terminated byte representation and allocate global symbol.
56   SmallVector<uint8_t> elementVals;
57   elementVals.append(msg.begin(), msg.end());
58   elementVals.push_back(0);
59   auto dataAttrType = RankedTensorType::get(
60       {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
61   auto dataAttr =
62       DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
63   auto arrayTy =
64       LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
65   std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
66   auto globalOp = builder.create<LLVM::GlobalOp>(
67       loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
68       dataAttr);
69 
70   // Emit call to `printStr` in runtime library.
71   builder.restoreInsertionPoint(ip);
72   auto msgAddr = builder.create<LLVM::AddressOfOp>(
73       loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
74   SmallVector<LLVM::GEPArg> indices(1, 0);
75   Value gep = builder.create<LLVM::GEPOp>(
76       loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
77       indices);
78   Operation *printer = LLVM::lookupOrCreatePrintStrFn(
79       moduleOp, typeConverter.useOpaquePointers());
80   builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
81                                gep);
82 }
83 
84 namespace {
85 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
86 /// assertion is violated and has no effect otherwise. The failure message is
87 /// ignored by the default lowering but should be propagated by any custom
88 /// lowering.
89 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
90   explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
91                             bool abortOnFailedAssert = true)
92       : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
93         abortOnFailedAssert(abortOnFailedAssert) {}
94 
95   LogicalResult
96   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
97                   ConversionPatternRewriter &rewriter) const override {
98     auto loc = op.getLoc();
99     auto module = op->getParentOfType<ModuleOp>();
100 
101     // Split block at `assert` operation.
102     Block *opBlock = rewriter.getInsertionBlock();
103     auto opPosition = rewriter.getInsertionPoint();
104     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
105 
106     // Failed block: Generate IR to print the message and call `abort`.
107     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
108     createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
109     if (abortOnFailedAssert) {
110       // Insert the `abort` declaration if necessary.
111       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
112       if (!abortFunc) {
113         OpBuilder::InsertionGuard guard(rewriter);
114         rewriter.setInsertionPointToStart(module.getBody());
115         auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
116         abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
117                                                       "abort", abortFuncTy);
118       }
119       rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
120       rewriter.create<LLVM::UnreachableOp>(loc);
121     } else {
122       rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
123     }
124 
125     // Generate assertion test.
126     rewriter.setInsertionPointToEnd(opBlock);
127     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
128         op, adaptor.getArg(), continuationBlock, failureBlock);
129 
130     return success();
131   }
132 
133 private:
134   /// If set to `false`, messages are printed but program execution continues.
135   /// This is useful for testing asserts.
136   bool abortOnFailedAssert = true;
137 };
138 
139 /// The cf->LLVM lowerings for branching ops require that the blocks they jump
140 /// to first have updated types which should be handled by a pattern operating
141 /// on the parent op.
142 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
143                                           ValueRange operands,
144                                           ValueRange blockArgs, Location loc,
145                                           llvm::StringRef messagePrefix) {
146   for (const auto &idxAndTypes :
147        llvm::enumerate(llvm::zip(blockArgs, operands))) {
148     int64_t i = idxAndTypes.index();
149     Value argValue =
150         rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
151     Type operandType = std::get<1>(idxAndTypes.value()).getType();
152     // In the case of an invalid jump, the block argument will have been
153     // remapped to an UnrealizedConversionCast. In the case of a valid jump,
154     // there might still be a no-op conversion cast with both types being equal.
155     // Consider both of these details to see if the jump would be invalid.
156     if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
157             argValue.getDefiningOp())) {
158       if (op.getOperandTypes().front() != operandType) {
159         return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
160           diag << messagePrefix;
161           diag << "mismatched types from operand # " << i << " ";
162           diag << operandType;
163           diag << " not compatible with destination block argument type ";
164           diag << op.getOperandTypes().front();
165           diag << " which should be converted with the parent op.";
166         });
167       }
168     }
169   }
170   return success();
171 }
172 
173 /// Ensure that all block types were updated and then create an LLVM::BrOp
174 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
175   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
176 
177   LogicalResult
178   matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
179                   ConversionPatternRewriter &rewriter) const override {
180     if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
181                                     op.getSuccessor()->getArguments(),
182                                     op.getLoc(),
183                                     /*messagePrefix=*/"")))
184       return failure();
185 
186     rewriter.replaceOpWithNewOp<LLVM::BrOp>(
187         op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
188     return success();
189   }
190 };
191 
192 /// Ensure that all block types were updated and then create an LLVM::CondBrOp
193 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
194   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
195 
196   LogicalResult
197   matchAndRewrite(cf::CondBranchOp op,
198                   typename cf::CondBranchOp::Adaptor adaptor,
199                   ConversionPatternRewriter &rewriter) const override {
200     if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
201                                     op.getFalseDest()->getArguments(),
202                                     op.getLoc(), "in false case branch ")))
203       return failure();
204     if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
205                                     op.getTrueDest()->getArguments(),
206                                     op.getLoc(), "in true case branch ")))
207       return failure();
208 
209     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
210         op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
211     return success();
212   }
213 };
214 
215 /// Ensure that all block types were updated and then create an LLVM::SwitchOp
216 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
217   using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
218 
219   LogicalResult
220   matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
221                   ConversionPatternRewriter &rewriter) const override {
222     if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
223                                     op.getDefaultDestination()->getArguments(),
224                                     op.getLoc(), "in switch default case ")))
225       return failure();
226 
227     for (const auto &i : llvm::enumerate(
228              llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
229       if (failed(verifyMatchingValues(
230               rewriter, std::get<0>(i.value()),
231               std::get<1>(i.value())->getArguments(), op.getLoc(),
232               "in switch case " + std::to_string(i.index()) + " "))) {
233         return failure();
234       }
235     }
236 
237     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
238         op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
239     return success();
240   }
241 };
242 
243 } // namespace
244 
245 void mlir::cf::populateControlFlowToLLVMConversionPatterns(
246     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
247   // clang-format off
248   patterns.add<
249       AssertOpLowering,
250       BranchOpLowering,
251       CondBranchOpLowering,
252       SwitchOpLowering>(converter);
253   // clang-format on
254 }
255 
256 void mlir::cf::populateAssertToLLVMConversionPattern(
257     LLVMTypeConverter &converter, RewritePatternSet &patterns,
258     bool abortOnFailure) {
259   patterns.add<AssertOpLowering>(converter, abortOnFailure);
260 }
261 
262 //===----------------------------------------------------------------------===//
263 // Pass Definition
264 //===----------------------------------------------------------------------===//
265 
266 namespace {
267 /// A pass converting MLIR operations into the LLVM IR dialect.
268 struct ConvertControlFlowToLLVM
269     : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
270 
271   using Base::Base;
272 
273   /// Run the dialect converter on the module.
274   void runOnOperation() override {
275     LLVMConversionTarget target(getContext());
276     RewritePatternSet patterns(&getContext());
277 
278     LowerToLLVMOptions options(&getContext());
279     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
280       options.overrideIndexBitwidth(indexBitwidth);
281     options.useOpaquePointers = useOpaquePointers;
282 
283     LLVMTypeConverter converter(&getContext(), options);
284     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
285 
286     if (failed(applyPartialConversion(getOperation(), target,
287                                       std::move(patterns))))
288       signalPassFailure();
289   }
290 };
291 } // namespace
292 
293 //===----------------------------------------------------------------------===//
294 // ConvertToLLVMPatternInterface implementation
295 //===----------------------------------------------------------------------===//
296 
297 namespace {
298 /// Implement the interface to convert MemRef to LLVM.
299 struct ControlFlowToLLVMDialectInterface
300     : public ConvertToLLVMPatternInterface {
301   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
302   void loadDependentDialects(MLIRContext *context) const final {
303     context->loadDialect<LLVM::LLVMDialect>();
304   }
305 
306   /// Hook for derived dialect interface to provide conversion patterns
307   /// and mark dialect legal for the conversion target.
308   void populateConvertToLLVMConversionPatterns(
309       ConversionTarget &target, LLVMTypeConverter &typeConverter,
310       RewritePatternSet &patterns) const final {
311     mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
312                                                           patterns);
313   }
314 };
315 } // namespace
316 
317 void mlir::cf::registerConvertControlFlowToLLVMInterface(
318     DialectRegistry &registry) {
319   registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
320     dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
321   });
322 }
323