xref: /llvm-project/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (revision 52aa6bd3fd6f48f583beaadcbb53edc0f3def4a1)
1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns and a pass to convert the MLIR ControlFlow
// (`cf`) dialect into the LLVM IR dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15 
16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18 #include "mlir/Conversion/LLVMCommon/Pattern.h"
19 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
20 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
21 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/ADT/StringRef.h"
28 #include <functional>
29 
30 namespace mlir {
31 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
32 #include "mlir/Conversion/Passes.h.inc"
33 } // namespace mlir
34 
35 using namespace mlir;
36 
37 #define PASS_NAME "convert-cf-to-llvm"
38 
39 static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
40   std::string prefix = "assert_msg_";
41   int counter = 0;
42   while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
43     ++counter;
44   return prefix + std::to_string(counter);
45 }
46 
47 /// Generate IR that prints the given string to stderr.
48 static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
49                            StringRef msg, LLVMTypeConverter &typeConverter) {
50   auto ip = builder.saveInsertionPoint();
51   builder.setInsertionPointToStart(moduleOp.getBody());
52   MLIRContext *ctx = builder.getContext();
53 
54   // Create a zero-terminated byte representation and allocate global symbol.
55   SmallVector<uint8_t> elementVals;
56   elementVals.append(msg.begin(), msg.end());
57   elementVals.push_back(0);
58   auto dataAttrType = RankedTensorType::get(
59       {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
60   auto dataAttr =
61       DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
62   auto arrayTy =
63       LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
64   std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
65   auto globalOp = builder.create<LLVM::GlobalOp>(
66       loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
67       dataAttr);
68 
69   // Emit call to `printStr` in runtime library.
70   builder.restoreInsertionPoint(ip);
71   auto msgAddr = builder.create<LLVM::AddressOfOp>(
72       loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
73   SmallVector<LLVM::GEPArg> indices(1, 0);
74   Value gep = builder.create<LLVM::GEPOp>(
75       loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
76       indices);
77   Operation *printer = LLVM::lookupOrCreatePrintStrFn(
78       moduleOp, typeConverter.useOpaquePointers());
79   builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
80                                gep);
81 }
82 
83 namespace {
84 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
85 /// assertion is violated and has no effect otherwise. The failure message is
86 /// ignored by the default lowering but should be propagated by any custom
87 /// lowering.
88 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
89   explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
90                             bool abortOnFailedAssert = true)
91       : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
92         abortOnFailedAssert(abortOnFailedAssert) {}
93 
94   LogicalResult
95   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
96                   ConversionPatternRewriter &rewriter) const override {
97     auto loc = op.getLoc();
98     auto module = op->getParentOfType<ModuleOp>();
99 
100     // Split block at `assert` operation.
101     Block *opBlock = rewriter.getInsertionBlock();
102     auto opPosition = rewriter.getInsertionPoint();
103     Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
104 
105     // Failed block: Generate IR to print the message and call `abort`.
106     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
107     createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
108     if (abortOnFailedAssert) {
109       // Insert the `abort` declaration if necessary.
110       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
111       if (!abortFunc) {
112         OpBuilder::InsertionGuard guard(rewriter);
113         rewriter.setInsertionPointToStart(module.getBody());
114         auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
115         abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
116                                                       "abort", abortFuncTy);
117       }
118       rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
119       rewriter.create<LLVM::UnreachableOp>(loc);
120     } else {
121       rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
122     }
123 
124     // Generate assertion test.
125     rewriter.setInsertionPointToEnd(opBlock);
126     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
127         op, adaptor.getArg(), continuationBlock, failureBlock);
128 
129     return success();
130   }
131 
132 private:
133   /// If set to `false`, messages are printed but program execution continues.
134   /// This is useful for testing asserts.
135   bool abortOnFailedAssert = true;
136 };
137 
138 /// The cf->LLVM lowerings for branching ops require that the blocks they jump
139 /// to first have updated types which should be handled by a pattern operating
140 /// on the parent op.
141 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
142                                           ValueRange operands,
143                                           ValueRange blockArgs, Location loc,
144                                           llvm::StringRef messagePrefix) {
145   for (const auto &idxAndTypes :
146        llvm::enumerate(llvm::zip(blockArgs, operands))) {
147     int64_t i = idxAndTypes.index();
148     Value argValue =
149         rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
150     Type operandType = std::get<1>(idxAndTypes.value()).getType();
151     // In the case of an invalid jump, the block argument will have been
152     // remapped to an UnrealizedConversionCast. In the case of a valid jump,
153     // there might still be a no-op conversion cast with both types being equal.
154     // Consider both of these details to see if the jump would be invalid.
155     if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
156             argValue.getDefiningOp())) {
157       if (op.getOperandTypes().front() != operandType) {
158         return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
159           diag << messagePrefix;
160           diag << "mismatched types from operand # " << i << " ";
161           diag << operandType;
162           diag << " not compatible with destination block argument type ";
163           diag << op.getOperandTypes().front();
164           diag << " which should be converted with the parent op.";
165         });
166       }
167     }
168   }
169   return success();
170 }
171 
172 /// Ensure that all block types were updated and then create an LLVM::BrOp
173 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
174   using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
175 
176   LogicalResult
177   matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
178                   ConversionPatternRewriter &rewriter) const override {
179     if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
180                                     op.getSuccessor()->getArguments(),
181                                     op.getLoc(),
182                                     /*messagePrefix=*/"")))
183       return failure();
184 
185     rewriter.replaceOpWithNewOp<LLVM::BrOp>(
186         op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
187     return success();
188   }
189 };
190 
191 /// Ensure that all block types were updated and then create an LLVM::CondBrOp
192 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
193   using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
194 
195   LogicalResult
196   matchAndRewrite(cf::CondBranchOp op,
197                   typename cf::CondBranchOp::Adaptor adaptor,
198                   ConversionPatternRewriter &rewriter) const override {
199     if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
200                                     op.getFalseDest()->getArguments(),
201                                     op.getLoc(), "in false case branch ")))
202       return failure();
203     if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
204                                     op.getTrueDest()->getArguments(),
205                                     op.getLoc(), "in true case branch ")))
206       return failure();
207 
208     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
209         op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
210     return success();
211   }
212 };
213 
214 /// Ensure that all block types were updated and then create an LLVM::SwitchOp
215 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
216   using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
217 
218   LogicalResult
219   matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
220                   ConversionPatternRewriter &rewriter) const override {
221     if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
222                                     op.getDefaultDestination()->getArguments(),
223                                     op.getLoc(), "in switch default case ")))
224       return failure();
225 
226     for (const auto &i : llvm::enumerate(
227              llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
228       if (failed(verifyMatchingValues(
229               rewriter, std::get<0>(i.value()),
230               std::get<1>(i.value())->getArguments(), op.getLoc(),
231               "in switch case " + std::to_string(i.index()) + " "))) {
232         return failure();
233       }
234     }
235 
236     rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
237         op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
238     return success();
239   }
240 };
241 
242 } // namespace
243 
244 void mlir::cf::populateControlFlowToLLVMConversionPatterns(
245     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
246   // clang-format off
247   patterns.add<
248       AssertOpLowering,
249       BranchOpLowering,
250       CondBranchOpLowering,
251       SwitchOpLowering>(converter);
252   // clang-format on
253 }
254 
255 void mlir::cf::populateAssertToLLVMConversionPattern(
256     LLVMTypeConverter &converter, RewritePatternSet &patterns,
257     bool abortOnFailure) {
258   patterns.add<AssertOpLowering>(converter, abortOnFailure);
259 }
260 
261 //===----------------------------------------------------------------------===//
262 // Pass Definition
263 //===----------------------------------------------------------------------===//
264 
265 namespace {
266 /// A pass converting MLIR operations into the LLVM IR dialect.
267 struct ConvertControlFlowToLLVM
268     : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
269 
270   using Base::Base;
271 
272   /// Run the dialect converter on the module.
273   void runOnOperation() override {
274     LLVMConversionTarget target(getContext());
275     RewritePatternSet patterns(&getContext());
276 
277     LowerToLLVMOptions options(&getContext());
278     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
279       options.overrideIndexBitwidth(indexBitwidth);
280     options.useOpaquePointers = useOpaquePointers;
281 
282     LLVMTypeConverter converter(&getContext(), options);
283     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
284 
285     if (failed(applyPartialConversion(getOperation(), target,
286                                       std::move(patterns))))
287       signalPassFailure();
288   }
289 };
290 } // namespace
291 
292 //===----------------------------------------------------------------------===//
293 // ConvertToLLVMPatternInterface implementation
294 //===----------------------------------------------------------------------===//
295 
296 namespace {
297 /// Implement the interface to convert MemRef to LLVM.
298 struct ControlFlowToLLVMDialectInterface
299     : public ConvertToLLVMPatternInterface {
300   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
301   void loadDependentDialects(MLIRContext *context) const final {
302     context->loadDialect<LLVM::LLVMDialect>();
303   }
304 
305   /// Hook for derived dialect interface to provide conversion patterns
306   /// and mark dialect legal for the conversion target.
307   void populateConvertToLLVMConversionPatterns(
308       ConversionTarget &target, LLVMTypeConverter &typeConverter,
309       RewritePatternSet &patterns) const final {
310     mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
311                                                           patterns);
312   }
313 };
314 } // namespace
315 
316 void mlir::cf::registerConvertControlFlowToLLVMInterface(
317     DialectRegistry &registry) {
318   registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
319     dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
320   });
321 }
322