1 //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert MLIR standard and builtin dialects 10 // into the LLVM IR dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 15 16 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 17 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 18 #include "mlir/Conversion/LLVMCommon/Pattern.h" 19 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 20 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 21 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 23 #include "mlir/IR/BuiltinOps.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 #include "llvm/ADT/StringRef.h" 28 #include <functional> 29 30 namespace mlir { 31 #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS 32 #include "mlir/Conversion/Passes.h.inc" 33 } // namespace mlir 34 35 using namespace mlir; 36 37 #define PASS_NAME "convert-cf-to-llvm" 38 39 static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) { 40 std::string prefix = "assert_msg_"; 41 int counter = 0; 42 while (moduleOp.lookupSymbol(prefix + std::to_string(counter))) 43 ++counter; 44 return prefix + std::to_string(counter); 45 } 46 47 /// Generate IR that prints the given string to stderr. 48 static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp, 49 StringRef msg, 50 const LLVMTypeConverter &typeConverter) { 51 auto ip = builder.saveInsertionPoint(); 52 builder.setInsertionPointToStart(moduleOp.getBody()); 53 MLIRContext *ctx = builder.getContext(); 54 55 // Create a zero-terminated byte representation and allocate global symbol. 56 SmallVector<uint8_t> elementVals; 57 elementVals.append(msg.begin(), msg.end()); 58 elementVals.push_back(0); 59 auto dataAttrType = RankedTensorType::get( 60 {static_cast<int64_t>(elementVals.size())}, builder.getI8Type()); 61 auto dataAttr = 62 DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals)); 63 auto arrayTy = 64 LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size()); 65 std::string symbolName = generateGlobalMsgSymbolName(moduleOp); 66 auto globalOp = builder.create<LLVM::GlobalOp>( 67 loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName, 68 dataAttr); 69 70 // Emit call to `printStr` in runtime library. 71 builder.restoreInsertionPoint(ip); 72 auto msgAddr = builder.create<LLVM::AddressOfOp>( 73 loc, typeConverter.getPointerType(arrayTy), globalOp.getName()); 74 SmallVector<LLVM::GEPArg> indices(1, 0); 75 Value gep = builder.create<LLVM::GEPOp>( 76 loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr, 77 indices); 78 Operation *printer = LLVM::lookupOrCreatePrintStrFn( 79 moduleOp, typeConverter.useOpaquePointers()); 80 builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer), 81 gep); 82 } 83 84 namespace { 85 /// Lower `cf.assert`. The default lowering calls the `abort` function if the 86 /// assertion is violated and has no effect otherwise. The failure message is 87 /// ignored by the default lowering but should be propagated by any custom 88 /// lowering. 89 struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { 90 explicit AssertOpLowering(LLVMTypeConverter &typeConverter, 91 bool abortOnFailedAssert = true) 92 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), 93 abortOnFailedAssert(abortOnFailedAssert) {} 94 95 LogicalResult 96 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, 97 ConversionPatternRewriter &rewriter) const override { 98 auto loc = op.getLoc(); 99 auto module = op->getParentOfType<ModuleOp>(); 100 101 // Split block at `assert` operation. 102 Block *opBlock = rewriter.getInsertionBlock(); 103 auto opPosition = rewriter.getInsertionPoint(); 104 Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); 105 106 // Failed block: Generate IR to print the message and call `abort`. 107 Block *failureBlock = rewriter.createBlock(opBlock->getParent()); 108 createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter()); 109 if (abortOnFailedAssert) { 110 // Insert the `abort` declaration if necessary. 111 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); 112 if (!abortFunc) { 113 OpBuilder::InsertionGuard guard(rewriter); 114 rewriter.setInsertionPointToStart(module.getBody()); 115 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); 116 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), 117 "abort", abortFuncTy); 118 } 119 rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt); 120 rewriter.create<LLVM::UnreachableOp>(loc); 121 } else { 122 rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock); 123 } 124 125 // Generate assertion test. 126 rewriter.setInsertionPointToEnd(opBlock); 127 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 128 op, adaptor.getArg(), continuationBlock, failureBlock); 129 130 return success(); 131 } 132 133 private: 134 /// If set to `false`, messages are printed but program execution continues. 135 /// This is useful for testing asserts. 136 bool abortOnFailedAssert = true; 137 }; 138 139 /// The cf->LLVM lowerings for branching ops require that the blocks they jump 140 /// to first have updated types which should be handled by a pattern operating 141 /// on the parent op. 142 static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter, 143 ValueRange operands, 144 ValueRange blockArgs, Location loc, 145 llvm::StringRef messagePrefix) { 146 for (const auto &idxAndTypes : 147 llvm::enumerate(llvm::zip(blockArgs, operands))) { 148 int64_t i = idxAndTypes.index(); 149 Value argValue = 150 rewriter.getRemappedValue(std::get<0>(idxAndTypes.value())); 151 Type operandType = std::get<1>(idxAndTypes.value()).getType(); 152 // In the case of an invalid jump, the block argument will have been 153 // remapped to an UnrealizedConversionCast. In the case of a valid jump, 154 // there might still be a no-op conversion cast with both types being equal. 155 // Consider both of these details to see if the jump would be invalid. 156 if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>( 157 argValue.getDefiningOp())) { 158 if (op.getOperandTypes().front() != operandType) { 159 return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) { 160 diag << messagePrefix; 161 diag << "mismatched types from operand # " << i << " "; 162 diag << operandType; 163 diag << " not compatible with destination block argument type "; 164 diag << op.getOperandTypes().front(); 165 diag << " which should be converted with the parent op."; 166 }); 167 } 168 } 169 } 170 return success(); 171 } 172 173 /// Ensure that all block types were updated and then create an LLVM::BrOp 174 struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { 175 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; 176 177 LogicalResult 178 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, 179 ConversionPatternRewriter &rewriter) const override { 180 if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(), 181 op.getSuccessor()->getArguments(), 182 op.getLoc(), 183 /*messagePrefix=*/""))) 184 return failure(); 185 186 rewriter.replaceOpWithNewOp<LLVM::BrOp>( 187 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 188 return success(); 189 } 190 }; 191 192 /// Ensure that all block types were updated and then create an LLVM::CondBrOp 193 struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { 194 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; 195 196 LogicalResult 197 matchAndRewrite(cf::CondBranchOp op, 198 typename cf::CondBranchOp::Adaptor adaptor, 199 ConversionPatternRewriter &rewriter) const override { 200 if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(), 201 op.getFalseDest()->getArguments(), 202 op.getLoc(), "in false case branch "))) 203 return failure(); 204 if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(), 205 op.getTrueDest()->getArguments(), 206 op.getLoc(), "in true case branch "))) 207 return failure(); 208 209 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 210 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 211 return success(); 212 } 213 }; 214 215 /// Ensure that all block types were updated and then create an LLVM::SwitchOp 216 struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { 217 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; 218 219 LogicalResult 220 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, 221 ConversionPatternRewriter &rewriter) const override { 222 if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(), 223 op.getDefaultDestination()->getArguments(), 224 op.getLoc(), "in switch default case "))) 225 return failure(); 226 227 for (const auto &i : llvm::enumerate( 228 llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) { 229 if (failed(verifyMatchingValues( 230 rewriter, std::get<0>(i.value()), 231 std::get<1>(i.value())->getArguments(), op.getLoc(), 232 "in switch case " + std::to_string(i.index()) + " "))) { 233 return failure(); 234 } 235 } 236 237 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( 238 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); 239 return success(); 240 } 241 }; 242 243 } // namespace 244 245 void mlir::cf::populateControlFlowToLLVMConversionPatterns( 246 LLVMTypeConverter &converter, RewritePatternSet &patterns) { 247 // clang-format off 248 patterns.add< 249 AssertOpLowering, 250 BranchOpLowering, 251 CondBranchOpLowering, 252 SwitchOpLowering>(converter); 253 // clang-format on 254 } 255 256 void mlir::cf::populateAssertToLLVMConversionPattern( 257 LLVMTypeConverter &converter, RewritePatternSet &patterns, 258 bool abortOnFailure) { 259 patterns.add<AssertOpLowering>(converter, abortOnFailure); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // Pass Definition 264 //===----------------------------------------------------------------------===// 265 266 namespace { 267 /// A pass converting MLIR operations into the LLVM IR dialect. 268 struct ConvertControlFlowToLLVM 269 : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> { 270 271 using Base::Base; 272 273 /// Run the dialect converter on the module. 274 void runOnOperation() override { 275 LLVMConversionTarget target(getContext()); 276 RewritePatternSet patterns(&getContext()); 277 278 LowerToLLVMOptions options(&getContext()); 279 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 280 options.overrideIndexBitwidth(indexBitwidth); 281 options.useOpaquePointers = useOpaquePointers; 282 283 LLVMTypeConverter converter(&getContext(), options); 284 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 285 286 if (failed(applyPartialConversion(getOperation(), target, 287 std::move(patterns)))) 288 signalPassFailure(); 289 } 290 }; 291 } // namespace 292 293 //===----------------------------------------------------------------------===// 294 // ConvertToLLVMPatternInterface implementation 295 //===----------------------------------------------------------------------===// 296 297 namespace { 298 /// Implement the interface to convert MemRef to LLVM. 299 struct ControlFlowToLLVMDialectInterface 300 : public ConvertToLLVMPatternInterface { 301 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 302 void loadDependentDialects(MLIRContext *context) const final { 303 context->loadDialect<LLVM::LLVMDialect>(); 304 } 305 306 /// Hook for derived dialect interface to provide conversion patterns 307 /// and mark dialect legal for the conversion target. 308 void populateConvertToLLVMConversionPatterns( 309 ConversionTarget &target, LLVMTypeConverter &typeConverter, 310 RewritePatternSet &patterns) const final { 311 mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, 312 patterns); 313 } 314 }; 315 } // namespace 316 317 void mlir::cf::registerConvertControlFlowToLLVMInterface( 318 DialectRegistry ®istry) { 319 registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { 320 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); 321 }); 322 } 323