1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert MLIR standard and builtin dialects 10 // into the LLVM IR dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 15 16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 17 #include "mlir/Conversion/LLVMCommon/Pattern.h" 18 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Pass/Pass.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "llvm/ADT/StringRef.h" 27 #include <functional> 28 29 namespace mlir { 30 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVM 31 #include "mlir/Conversion/Passes.h.inc" 32 } // namespace mlir 33 34 using namespace mlir; 35 36 #define PASS_NAME "convert-cf-to-llvm" 37 38 static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) { 39 std::string prefix = "assert_msg_"; 40 int counter = 0; 41 while (moduleOp.lookupSymbol(prefix + std::to_string(counter))) 42 ++counter; 43 return prefix + std::to_string(counter); 44 } 45 46 /// Generate IR that prints the given string to stderr. 47 static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp, 48 StringRef msg) { 49 auto ip = builder.saveInsertionPoint(); 50 builder.setInsertionPointToStart(moduleOp.getBody()); 51 MLIRContext *ctx = builder.getContext(); 52 53 // Create a zero-terminated byte representation and allocate global symbol. 54 SmallVector<uint8_t> elementVals; 55 elementVals.append(msg.begin(), msg.end()); 56 elementVals.push_back(0); 57 auto dataAttrType = RankedTensorType::get( 58 {static_cast<int64_t>(elementVals.size())}, builder.getI8Type()); 59 auto dataAttr = 60 DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals)); 61 auto arrayTy = 62 LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); 63 std::string symbolName = generateGlobalMsgSymbolName(moduleOp); 64 auto globalOp = builder.create<LLVM::GlobalOp>( 65 loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName, 66 dataAttr); 67 68 // Emit call to `printStr` in runtime library. 69 builder.restoreInsertionPoint(ip); 70 auto msgAddr = builder.create<LLVM::AddressOfOp>( 71 loc, LLVM::LLVMPointerType::get(arrayTy), globalOp.getName()); 72 SmallVector<LLVM::GEPArg> indices(1, 0); 73 Value gep = builder.create<LLVM::GEPOp>( 74 loc, LLVM::LLVMPointerType::get(builder.getI8Type()), msgAddr, indices); 75 Operation *printer = LLVM::lookupOrCreatePrintStrFn(moduleOp); 76 builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer), 77 gep); 78 } 79 80 namespace { 81 /// Lower `cf.assert`. The default lowering calls the `abort` function if the 82 /// assertion is violated and has no effect otherwise. The failure message is 83 /// ignored by the default lowering but should be propagated by any custom 84 /// lowering. 85 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { 86 explicit AssertOpLowering(LLVMTypeConverter &typeConverter, 87 bool abortOnFailedAssert = true) 88 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), 89 abortOnFailedAssert(abortOnFailedAssert) {} 90 91 LogicalResult 92 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 93 ConversionPatternRewriter &rewriter) const override { 94 auto loc = op.getLoc(); 95 auto module = op->getParentOfType<ModuleOp>(); 96 97 // Split block at `assert` operation. 98 Block *opBlock = rewriter.getInsertionBlock(); 99 auto opPosition = rewriter.getInsertionPoint(); 100 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); 101 102 // Failed block: Generate IR to print the message and call `abort`. 103 Block *failureBlock = rewriter.createBlock(opBlock->getParent()); 104 createPrintMsg(rewriter, loc, module, op.getMsg()); 105 if (abortOnFailedAssert) { 106 // Insert the `abort` declaration if necessary. 107 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); 108 if (!abortFunc) { 109 OpBuilder::InsertionGuard guard(rewriter); 110 rewriter.setInsertionPointToStart(module.getBody()); 111 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); 112 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), 113 "abort", abortFuncTy); 114 } 115 rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt); 116 rewriter.create<LLVM::UnreachableOp>(loc); 117 } else { 118 rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock); 119 } 120 121 // Generate assertion test. 122 rewriter.setInsertionPointToEnd(opBlock); 123 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 124 op, adaptor.getArg(), continuationBlock, failureBlock); 125 126 return success(); 127 } 128 129 private: 130 /// If set to `false`, messages are printed but program execution continues. 131 /// This is useful for testing asserts. 132 bool abortOnFailedAssert = true; 133 }; 134 135 /// The cf->LLVM lowerings for branching ops require that the blocks they jump 136 /// to first have updated types which should be handled by a pattern operating 137 /// on the parent op. 138 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter, 139 ValueRange operands, 140 ValueRange blockArgs, Location loc, 141 llvm::StringRef messagePrefix) { 142 for (const auto &idxAndTypes : 143 llvm::enumerate(llvm::zip(blockArgs, operands))) { 144 int64_t i = idxAndTypes.index(); 145 Value argValue = 146 rewriter.getRemappedValue(std::get<0>(idxAndTypes.value())); 147 Type operandType = std::get<1>(idxAndTypes.value()).getType(); 148 // In the case of an invalid jump, the block argument will have been 149 // remapped to an UnrealizedConversionCast. In the case of a valid jump, 150 // there might still be a no-op conversion cast with both types being equal. 151 // Consider both of these details to see if the jump would be invalid. 152 if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>( 153 argValue.getDefiningOp())) { 154 if (op.getOperandTypes().front() != operandType) { 155 return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) { 156 diag << messagePrefix; 157 diag << "mismatched types from operand # " << i << " "; 158 diag << operandType; 159 diag << " not compatible with destination block argument type "; 160 diag << op.getOperandTypes().front(); 161 diag << " which should be converted with the parent op."; 162 }); 163 } 164 } 165 } 166 return success(); 167 } 168 169 /// Ensure that all block types were updated and then create an LLVM::BrOp 170 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { 171 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; 172 173 LogicalResult 174 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, 175 ConversionPatternRewriter &rewriter) const override { 176 if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(), 177 op.getSuccessor()->getArguments(), 178 op.getLoc(), 179 /*messagePrefix=*/""))) 180 return failure(); 181 182 rewriter.replaceOpWithNewOp<LLVM::BrOp>( 183 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 184 return success(); 185 } 186 }; 187 188 /// Ensure that all block types were updated and then create an LLVM::CondBrOp 189 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { 190 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; 191 192 LogicalResult 193 matchAndRewrite(cf::CondBranchOp op, 194 typename cf::CondBranchOp::Adaptor adaptor, 195 ConversionPatternRewriter &rewriter) const override { 196 if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(), 197 op.getFalseDest()->getArguments(), 198 op.getLoc(), "in false case branch "))) 199 return failure(); 200 if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(), 201 op.getTrueDest()->getArguments(), 202 op.getLoc(), "in true case branch "))) 203 return failure(); 204 205 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 206 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 207 return success(); 208 } 209 }; 210 211 /// Ensure that all block types were updated and then create an LLVM::SwitchOp 212 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { 213 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; 214 215 LogicalResult 216 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, 217 ConversionPatternRewriter &rewriter) const override { 218 if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(), 219 op.getDefaultDestination()->getArguments(), 220 op.getLoc(), "in switch default case "))) 221 return failure(); 222 223 for (const auto &i : llvm::enumerate( 224 llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) { 225 if (failed(verifyMatchingValues( 226 rewriter, std::get<0>(i.value()), 227 std::get<1>(i.value())->getArguments(), op.getLoc(), 228 "in switch case " + std::to_string(i.index()) + " "))) { 229 return failure(); 230 } 231 } 232 233 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 234 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 235 return success(); 236 } 237 }; 238 239 } // namespace 240 241 void mlir::cf::populateControlFlowToLLVMConversionPatterns( 242 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 243 // clang-format off 244 patterns.add< 245 AssertOpLowering, 246 BranchOpLowering, 247 CondBranchOpLowering, 248 SwitchOpLowering>(converter); 249 // clang-format on 250 } 251 252 void mlir::cf::populateAssertToLLVMConversionPattern( 253 LLVMTypeConverter &converter, RewritePatternSet &patterns, 254 bool abortOnFailure) { 255 patterns.add<AssertOpLowering>(converter, abortOnFailure); 256 } 257 258 //===----------------------------------------------------------------------===// 259 // Pass Definition 260 //===----------------------------------------------------------------------===// 261 262 namespace { 263 /// A pass converting MLIR operations into the LLVM IR dialect. 264 struct ConvertControlFlowToLLVM 265 : public impl::ConvertControlFlowToLLVMBase<ConvertControlFlowToLLVM> { 266 ConvertControlFlowToLLVM() = default; 267 268 /// Run the dialect converter on the module. 269 void runOnOperation() override { 270 LLVMConversionTarget target(getContext()); 271 RewritePatternSet patterns(&getContext()); 272 273 LowerToLLVMOptions options(&getContext()); 274 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 275 options.overrideIndexBitwidth(indexBitwidth); 276 277 LLVMTypeConverter converter(&getContext(), options); 278 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 279 280 if (failed(applyPartialConversion(getOperation(), target, 281 std::move(patterns)))) 282 signalPassFailure(); 283 } 284 }; 285 } // namespace 286 287 std::unique_ptr<Pass> mlir::cf::createConvertControlFlowToLLVMPass() { 288 return std::make_unique<ConvertControlFlowToLLVM>(); 289 } 290