1 //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" 10 11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" 12 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Arith/Transforms/Passes.h" 17 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Pass/Pass.h" 21 #include <type_traits> 22 23 namespace mlir { 24 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS 25 #include "mlir/Conversion/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 30 namespace { 31 32 /// Operations whose conversion will depend on whether they are passed a 33 /// rounding mode attribute or not. 34 /// 35 /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower 36 /// to; `AttrConvert` is the attribute conversion to convert the rounding mode 37 /// attribute. 38 template <typename SourceOp, typename TargetOp, bool Constrained, 39 template <typename, typename> typename AttrConvert = 40 AttrConvertPassThrough> 41 struct ConstrainedVectorConvertToLLVMPattern 42 : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> { 43 using VectorConvertToLLVMPattern<SourceOp, TargetOp, 44 AttrConvert>::VectorConvertToLLVMPattern; 45 46 LogicalResult 47 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 48 ConversionPatternRewriter &rewriter) const override { 49 if (Constrained != static_cast<bool>(op.getRoundingModeAttr())) 50 return failure(); 51 return VectorConvertToLLVMPattern<SourceOp, TargetOp, 52 AttrConvert>::matchAndRewrite(op, adaptor, 53 rewriter); 54 } 55 }; 56 57 //===----------------------------------------------------------------------===// 58 // Straightforward Op Lowerings 59 //===----------------------------------------------------------------------===// 60 61 using AddFOpLowering = 62 VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp, 63 arith::AttrConvertFastMathToLLVM>; 64 using AddIOpLowering = 65 VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp, 66 arith::AttrConvertOverflowToLLVM>; 67 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>; 68 using BitcastOpLowering = 69 VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>; 70 using DivFOpLowering = 71 VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp, 72 arith::AttrConvertFastMathToLLVM>; 73 using DivSIOpLowering = 74 VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>; 75 using DivUIOpLowering = 76 VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>; 77 using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>; 78 using ExtSIOpLowering = 79 VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>; 80 using ExtUIOpLowering = 81 VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>; 82 using FPToSIOpLowering = 83 VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>; 84 using FPToUIOpLowering = 85 VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>; 86 using MaximumFOpLowering = 87 VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp, 88 arith::AttrConvertFastMathToLLVM>; 89 using MaxNumFOpLowering = 90 VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp, 91 arith::AttrConvertFastMathToLLVM>; 92 using MaxSIOpLowering = 93 VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>; 94 using MaxUIOpLowering = 95 VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>; 96 using MinimumFOpLowering = 97 VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp, 98 arith::AttrConvertFastMathToLLVM>; 99 using MinNumFOpLowering = 100 VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp, 101 arith::AttrConvertFastMathToLLVM>; 102 using MinSIOpLowering = 103 VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>; 104 using MinUIOpLowering = 105 VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>; 106 using MulFOpLowering = 107 VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp, 108 arith::AttrConvertFastMathToLLVM>; 109 using MulIOpLowering = 110 VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp, 111 arith::AttrConvertOverflowToLLVM>; 112 using NegFOpLowering = 113 VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp, 114 arith::AttrConvertFastMathToLLVM>; 115 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>; 116 using RemFOpLowering = 117 VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp, 118 arith::AttrConvertFastMathToLLVM>; 119 using RemSIOpLowering = 120 VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>; 121 using RemUIOpLowering = 122 VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>; 123 using SelectOpLowering = 124 VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>; 125 using ShLIOpLowering = 126 VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp, 127 arith::AttrConvertOverflowToLLVM>; 128 using ShRSIOpLowering = 129 VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>; 130 using ShRUIOpLowering = 131 VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>; 132 using SIToFPOpLowering = 133 VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>; 134 using SubFOpLowering = 135 VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp, 136 arith::AttrConvertFastMathToLLVM>; 137 using SubIOpLowering = 138 VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp, 139 arith::AttrConvertOverflowToLLVM>; 140 using TruncFOpLowering = 141 ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp, 142 false>; 143 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< 144 arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, 145 arith::AttrConverterConstrainedFPToLLVM>; 146 using TruncIOpLowering = 147 VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>; 148 using UIToFPOpLowering = 149 VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>; 150 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>; 151 152 //===----------------------------------------------------------------------===// 153 // Op Lowering Patterns 154 //===----------------------------------------------------------------------===// 155 156 /// Directly lower to LLVM op. 157 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> { 158 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 159 160 LogicalResult 161 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 162 ConversionPatternRewriter &rewriter) const override; 163 }; 164 165 /// The lowering of index_cast becomes an integer conversion since index 166 /// becomes an integer. If the bit width of the source and target integer 167 /// types is the same, just erase the cast. If the target type is wider, 168 /// sign-extend the value, otherwise truncate it. 169 template <typename OpTy, typename ExtCastTy> 170 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> { 171 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern; 172 173 LogicalResult 174 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, 175 ConversionPatternRewriter &rewriter) const override; 176 }; 177 178 using IndexCastOpSILowering = 179 IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>; 180 using IndexCastOpUILowering = 181 IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>; 182 183 struct AddUIExtendedOpLowering 184 : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> { 185 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 186 187 LogicalResult 188 matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor, 189 ConversionPatternRewriter &rewriter) const override; 190 }; 191 192 template <typename ArithMulOp, bool IsSigned> 193 struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> { 194 using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern; 195 196 LogicalResult 197 matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor, 198 ConversionPatternRewriter &rewriter) const override; 199 }; 200 201 using MulSIExtendedOpLowering = 202 MulIExtendedOpLowering<arith::MulSIExtendedOp, true>; 203 using MulUIExtendedOpLowering = 204 MulIExtendedOpLowering<arith::MulUIExtendedOp, false>; 205 206 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> { 207 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 208 209 LogicalResult 210 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 211 ConversionPatternRewriter &rewriter) const override; 212 }; 213 214 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { 215 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; 216 217 LogicalResult 218 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 219 ConversionPatternRewriter &rewriter) const override; 220 }; 221 222 } // namespace 223 224 //===----------------------------------------------------------------------===// 225 // ConstantOpLowering 226 //===----------------------------------------------------------------------===// 227 228 LogicalResult 229 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, 230 ConversionPatternRewriter &rewriter) const { 231 return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), 232 adaptor.getOperands(), op->getAttrs(), 233 *getTypeConverter(), rewriter); 234 } 235 236 //===----------------------------------------------------------------------===// 237 // IndexCastOpLowering 238 //===----------------------------------------------------------------------===// 239 240 template <typename OpTy, typename ExtCastTy> 241 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite( 242 OpTy op, typename OpTy::Adaptor adaptor, 243 ConversionPatternRewriter &rewriter) const { 244 Type resultType = op.getResult().getType(); 245 Type targetElementType = 246 this->typeConverter->convertType(getElementTypeOrSelf(resultType)); 247 Type sourceElementType = 248 this->typeConverter->convertType(getElementTypeOrSelf(op.getIn())); 249 unsigned targetBits = targetElementType.getIntOrFloatBitWidth(); 250 unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth(); 251 252 if (targetBits == sourceBits) { 253 rewriter.replaceOp(op, adaptor.getIn()); 254 return success(); 255 } 256 257 // Handle the scalar and 1D vector cases. 258 Type operandType = adaptor.getIn().getType(); 259 if (!isa<LLVM::LLVMArrayType>(operandType)) { 260 Type targetType = this->typeConverter->convertType(resultType); 261 if (targetBits < sourceBits) 262 rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType, 263 adaptor.getIn()); 264 else 265 rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn()); 266 return success(); 267 } 268 269 if (!isa<VectorType>(resultType)) 270 return rewriter.notifyMatchFailure(op, "expected vector result type"); 271 272 return LLVM::detail::handleMultidimensionalVectors( 273 op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()), 274 [&](Type llvm1DVectorTy, ValueRange operands) -> Value { 275 typename OpTy::Adaptor adaptor(operands); 276 if (targetBits < sourceBits) { 277 return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy, 278 adaptor.getIn()); 279 } 280 return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy, 281 adaptor.getIn()); 282 }, 283 rewriter); 284 } 285 286 //===----------------------------------------------------------------------===// 287 // AddUIExtendedOpLowering 288 //===----------------------------------------------------------------------===// 289 290 LogicalResult AddUIExtendedOpLowering::matchAndRewrite( 291 arith::AddUIExtendedOp op, OpAdaptor adaptor, 292 ConversionPatternRewriter &rewriter) const { 293 Type operandType = adaptor.getLhs().getType(); 294 Type sumResultType = op.getSum().getType(); 295 Type overflowResultType = op.getOverflow().getType(); 296 297 if (!LLVM::isCompatibleType(operandType)) 298 return failure(); 299 300 MLIRContext *ctx = rewriter.getContext(); 301 Location loc = op.getLoc(); 302 303 // Handle the scalar and 1D vector cases. 304 if (!isa<LLVM::LLVMArrayType>(operandType)) { 305 Type newOverflowType = typeConverter->convertType(overflowResultType); 306 Type structType = 307 LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); 308 Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>( 309 loc, structType, adaptor.getLhs(), adaptor.getRhs()); 310 Value sumExtracted = 311 rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0); 312 Value overflowExtracted = 313 rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1); 314 rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); 315 return success(); 316 } 317 318 if (!isa<VectorType>(sumResultType)) 319 return rewriter.notifyMatchFailure(loc, "expected vector result types"); 320 321 return rewriter.notifyMatchFailure(loc, 322 "ND vector types are not supported yet"); 323 } 324 325 //===----------------------------------------------------------------------===// 326 // MulIExtendedOpLowering 327 //===----------------------------------------------------------------------===// 328 329 template <typename ArithMulOp, bool IsSigned> 330 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite( 331 ArithMulOp op, typename ArithMulOp::Adaptor adaptor, 332 ConversionPatternRewriter &rewriter) const { 333 Type resultType = adaptor.getLhs().getType(); 334 335 if (!LLVM::isCompatibleType(resultType)) 336 return failure(); 337 338 Location loc = op.getLoc(); 339 340 // Handle the scalar and 1D vector cases. Because LLVM does not have a 341 // matching extended multiplication intrinsic, perform regular multiplication 342 // on operands zero-extended to i(2*N) bits, and truncate the results back to 343 // iN types. 344 if (!isa<LLVM::LLVMArrayType>(resultType)) { 345 // Shift amount necessary to extract the high bits from widened result. 346 TypedAttr shiftValAttr; 347 348 if (auto intTy = dyn_cast<IntegerType>(resultType)) { 349 unsigned resultBitwidth = intTy.getWidth(); 350 auto attrTy = rewriter.getIntegerType(resultBitwidth * 2); 351 shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth); 352 } else { 353 auto vecTy = cast<VectorType>(resultType); 354 unsigned resultBitwidth = vecTy.getElementTypeBitWidth(); 355 auto attrTy = VectorType::get( 356 vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2)); 357 shiftValAttr = SplatElementsAttr::get( 358 attrTy, APInt(resultBitwidth * 2, resultBitwidth)); 359 } 360 Type wideType = shiftValAttr.getType(); 361 assert(LLVM::isCompatibleType(wideType) && 362 "LLVM dialect should support all signless integer types"); 363 364 using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>; 365 Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs()); 366 Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs()); 367 Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt); 368 369 // Split the 2*N-bit wide result into two N-bit values. 370 Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt); 371 Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr); 372 Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal); 373 Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt); 374 375 rewriter.replaceOp(op, {low, high}); 376 return success(); 377 } 378 379 if (!isa<VectorType>(resultType)) 380 return rewriter.notifyMatchFailure(op, "expected vector result type"); 381 382 return rewriter.notifyMatchFailure(op, 383 "ND vector types are not supported yet"); 384 } 385 386 //===----------------------------------------------------------------------===// 387 // CmpIOpLowering 388 //===----------------------------------------------------------------------===// 389 390 // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums 391 // share numerical values so just cast. 392 template <typename LLVMPredType, typename PredType> 393 static LLVMPredType convertCmpPredicate(PredType pred) { 394 return static_cast<LLVMPredType>(pred); 395 } 396 397 LogicalResult 398 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, 399 ConversionPatternRewriter &rewriter) const { 400 Type operandType = adaptor.getLhs().getType(); 401 Type resultType = op.getResult().getType(); 402 403 // Handle the scalar and 1D vector cases. 404 if (!isa<LLVM::LLVMArrayType>(operandType)) { 405 rewriter.replaceOpWithNewOp<LLVM::ICmpOp>( 406 op, typeConverter->convertType(resultType), 407 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 408 adaptor.getLhs(), adaptor.getRhs()); 409 return success(); 410 } 411 412 if (!isa<VectorType>(resultType)) 413 return rewriter.notifyMatchFailure(op, "expected vector result type"); 414 415 return LLVM::detail::handleMultidimensionalVectors( 416 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 417 [&](Type llvm1DVectorTy, ValueRange operands) { 418 OpAdaptor adaptor(operands); 419 return rewriter.create<LLVM::ICmpOp>( 420 op.getLoc(), llvm1DVectorTy, 421 convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()), 422 adaptor.getLhs(), adaptor.getRhs()); 423 }, 424 rewriter); 425 } 426 427 //===----------------------------------------------------------------------===// 428 // CmpFOpLowering 429 //===----------------------------------------------------------------------===// 430 431 LogicalResult 432 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, 433 ConversionPatternRewriter &rewriter) const { 434 Type operandType = adaptor.getLhs().getType(); 435 Type resultType = op.getResult().getType(); 436 LLVM::FastmathFlags fmf = 437 arith::convertArithFastMathFlagsToLLVM(op.getFastmath()); 438 439 // Handle the scalar and 1D vector cases. 440 if (!isa<LLVM::LLVMArrayType>(operandType)) { 441 rewriter.replaceOpWithNewOp<LLVM::FCmpOp>( 442 op, typeConverter->convertType(resultType), 443 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 444 adaptor.getLhs(), adaptor.getRhs(), fmf); 445 return success(); 446 } 447 448 if (!isa<VectorType>(resultType)) 449 return rewriter.notifyMatchFailure(op, "expected vector result type"); 450 451 return LLVM::detail::handleMultidimensionalVectors( 452 op.getOperation(), adaptor.getOperands(), *getTypeConverter(), 453 [&](Type llvm1DVectorTy, ValueRange operands) { 454 OpAdaptor adaptor(operands); 455 return rewriter.create<LLVM::FCmpOp>( 456 op.getLoc(), llvm1DVectorTy, 457 convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()), 458 adaptor.getLhs(), adaptor.getRhs(), fmf); 459 }, 460 rewriter); 461 } 462 463 //===----------------------------------------------------------------------===// 464 // Pass Definition 465 //===----------------------------------------------------------------------===// 466 467 namespace { 468 struct ArithToLLVMConversionPass 469 : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> { 470 using Base::Base; 471 472 void runOnOperation() override { 473 LLVMConversionTarget target(getContext()); 474 RewritePatternSet patterns(&getContext()); 475 476 LowerToLLVMOptions options(&getContext()); 477 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) 478 options.overrideIndexBitwidth(indexBitwidth); 479 480 LLVMTypeConverter converter(&getContext(), options); 481 arith::populateCeilFloorDivExpandOpsPatterns(patterns); 482 arith::populateArithToLLVMConversionPatterns(converter, patterns); 483 484 if (failed(applyPartialConversion(getOperation(), target, 485 std::move(patterns)))) 486 signalPassFailure(); 487 } 488 }; 489 } // namespace 490 491 //===----------------------------------------------------------------------===// 492 // ConvertToLLVMPatternInterface implementation 493 //===----------------------------------------------------------------------===// 494 495 namespace { 496 /// Implement the interface to convert MemRef to LLVM. 497 struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface { 498 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 499 void loadDependentDialects(MLIRContext *context) const final { 500 context->loadDialect<LLVM::LLVMDialect>(); 501 } 502 503 /// Hook for derived dialect interface to provide conversion patterns 504 /// and mark dialect legal for the conversion target. 505 void populateConvertToLLVMConversionPatterns( 506 ConversionTarget &target, LLVMTypeConverter &typeConverter, 507 RewritePatternSet &patterns) const final { 508 arith::populateCeilFloorDivExpandOpsPatterns(patterns); 509 arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); 510 } 511 }; 512 } // namespace 513 514 void mlir::arith::registerConvertArithToLLVMInterface( 515 DialectRegistry ®istry) { 516 registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { 517 dialect->addInterfaces<ArithToLLVMDialectInterface>(); 518 }); 519 } 520 521 //===----------------------------------------------------------------------===// 522 // Pattern Population 523 //===----------------------------------------------------------------------===// 524 525 void mlir::arith::populateArithToLLVMConversionPatterns( 526 const LLVMTypeConverter &converter, RewritePatternSet &patterns) { 527 // clang-format off 528 patterns.add< 529 AddFOpLowering, 530 AddIOpLowering, 531 AndIOpLowering, 532 AddUIExtendedOpLowering, 533 BitcastOpLowering, 534 ConstantOpLowering, 535 CmpFOpLowering, 536 CmpIOpLowering, 537 DivFOpLowering, 538 DivSIOpLowering, 539 DivUIOpLowering, 540 ExtFOpLowering, 541 ExtSIOpLowering, 542 ExtUIOpLowering, 543 FPToSIOpLowering, 544 FPToUIOpLowering, 545 IndexCastOpSILowering, 546 IndexCastOpUILowering, 547 MaximumFOpLowering, 548 MaxNumFOpLowering, 549 MaxSIOpLowering, 550 MaxUIOpLowering, 551 MinimumFOpLowering, 552 MinNumFOpLowering, 553 MinSIOpLowering, 554 MinUIOpLowering, 555 MulFOpLowering, 556 MulIOpLowering, 557 MulSIExtendedOpLowering, 558 MulUIExtendedOpLowering, 559 NegFOpLowering, 560 OrIOpLowering, 561 RemFOpLowering, 562 RemSIOpLowering, 563 RemUIOpLowering, 564 SelectOpLowering, 565 ShLIOpLowering, 566 ShRSIOpLowering, 567 ShRUIOpLowering, 568 SIToFPOpLowering, 569 SubFOpLowering, 570 SubIOpLowering, 571 TruncFOpLowering, 572 ConstrainedTruncFOpLowering, 573 TruncIOpLowering, 574 UIToFPOpLowering, 575 XOrIOpLowering 576 >(converter); 577 // clang-format on 578 } 579