xref: /llvm-project/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (revision 3cbc73f71eef599e678197e445e11a98f8f61689)
1 //===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
10 
11 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
12 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Transforms/Passes.h"
17 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
21 #include <type_traits>
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 
32 /// Operations whose conversion will depend on whether they are passed a
33 /// rounding mode attribute or not.
34 ///
35 /// `SourceOp` is the source operation; `TargetOp`, the operation it will lower
36 /// to; `AttrConvert` is the attribute conversion to convert the rounding mode
37 /// attribute.
38 template <typename SourceOp, typename TargetOp, bool Constrained,
39           template <typename, typename> typename AttrConvert =
40               AttrConvertPassThrough>
41 struct ConstrainedVectorConvertToLLVMPattern
42     : public VectorConvertToLLVMPattern<SourceOp, TargetOp, AttrConvert> {
43   using VectorConvertToLLVMPattern<SourceOp, TargetOp,
44                                    AttrConvert>::VectorConvertToLLVMPattern;
45 
46   LogicalResult
47   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
48                   ConversionPatternRewriter &rewriter) const override {
49     if (Constrained != static_cast<bool>(op.getRoundingModeAttr()))
50       return failure();
51     return VectorConvertToLLVMPattern<SourceOp, TargetOp,
52                                       AttrConvert>::matchAndRewrite(op, adaptor,
53                                                                     rewriter);
54   }
55 };
56 
57 //===----------------------------------------------------------------------===//
58 // Straightforward Op Lowerings
59 //===----------------------------------------------------------------------===//
60 
61 using AddFOpLowering =
62     VectorConvertToLLVMPattern<arith::AddFOp, LLVM::FAddOp,
63                                arith::AttrConvertFastMathToLLVM>;
64 using AddIOpLowering =
65     VectorConvertToLLVMPattern<arith::AddIOp, LLVM::AddOp,
66                                arith::AttrConvertOverflowToLLVM>;
67 using AndIOpLowering = VectorConvertToLLVMPattern<arith::AndIOp, LLVM::AndOp>;
68 using BitcastOpLowering =
69     VectorConvertToLLVMPattern<arith::BitcastOp, LLVM::BitcastOp>;
70 using DivFOpLowering =
71     VectorConvertToLLVMPattern<arith::DivFOp, LLVM::FDivOp,
72                                arith::AttrConvertFastMathToLLVM>;
73 using DivSIOpLowering =
74     VectorConvertToLLVMPattern<arith::DivSIOp, LLVM::SDivOp>;
75 using DivUIOpLowering =
76     VectorConvertToLLVMPattern<arith::DivUIOp, LLVM::UDivOp>;
77 using ExtFOpLowering = VectorConvertToLLVMPattern<arith::ExtFOp, LLVM::FPExtOp>;
78 using ExtSIOpLowering =
79     VectorConvertToLLVMPattern<arith::ExtSIOp, LLVM::SExtOp>;
80 using ExtUIOpLowering =
81     VectorConvertToLLVMPattern<arith::ExtUIOp, LLVM::ZExtOp>;
82 using FPToSIOpLowering =
83     VectorConvertToLLVMPattern<arith::FPToSIOp, LLVM::FPToSIOp>;
84 using FPToUIOpLowering =
85     VectorConvertToLLVMPattern<arith::FPToUIOp, LLVM::FPToUIOp>;
86 using MaximumFOpLowering =
87     VectorConvertToLLVMPattern<arith::MaximumFOp, LLVM::MaximumOp,
88                                arith::AttrConvertFastMathToLLVM>;
89 using MaxNumFOpLowering =
90     VectorConvertToLLVMPattern<arith::MaxNumFOp, LLVM::MaxNumOp,
91                                arith::AttrConvertFastMathToLLVM>;
92 using MaxSIOpLowering =
93     VectorConvertToLLVMPattern<arith::MaxSIOp, LLVM::SMaxOp>;
94 using MaxUIOpLowering =
95     VectorConvertToLLVMPattern<arith::MaxUIOp, LLVM::UMaxOp>;
96 using MinimumFOpLowering =
97     VectorConvertToLLVMPattern<arith::MinimumFOp, LLVM::MinimumOp,
98                                arith::AttrConvertFastMathToLLVM>;
99 using MinNumFOpLowering =
100     VectorConvertToLLVMPattern<arith::MinNumFOp, LLVM::MinNumOp,
101                                arith::AttrConvertFastMathToLLVM>;
102 using MinSIOpLowering =
103     VectorConvertToLLVMPattern<arith::MinSIOp, LLVM::SMinOp>;
104 using MinUIOpLowering =
105     VectorConvertToLLVMPattern<arith::MinUIOp, LLVM::UMinOp>;
106 using MulFOpLowering =
107     VectorConvertToLLVMPattern<arith::MulFOp, LLVM::FMulOp,
108                                arith::AttrConvertFastMathToLLVM>;
109 using MulIOpLowering =
110     VectorConvertToLLVMPattern<arith::MulIOp, LLVM::MulOp,
111                                arith::AttrConvertOverflowToLLVM>;
112 using NegFOpLowering =
113     VectorConvertToLLVMPattern<arith::NegFOp, LLVM::FNegOp,
114                                arith::AttrConvertFastMathToLLVM>;
115 using OrIOpLowering = VectorConvertToLLVMPattern<arith::OrIOp, LLVM::OrOp>;
116 using RemFOpLowering =
117     VectorConvertToLLVMPattern<arith::RemFOp, LLVM::FRemOp,
118                                arith::AttrConvertFastMathToLLVM>;
119 using RemSIOpLowering =
120     VectorConvertToLLVMPattern<arith::RemSIOp, LLVM::SRemOp>;
121 using RemUIOpLowering =
122     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
123 using SelectOpLowering =
124     VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
125 using ShLIOpLowering =
126     VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
127                                arith::AttrConvertOverflowToLLVM>;
128 using ShRSIOpLowering =
129     VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
130 using ShRUIOpLowering =
131     VectorConvertToLLVMPattern<arith::ShRUIOp, LLVM::LShrOp>;
132 using SIToFPOpLowering =
133     VectorConvertToLLVMPattern<arith::SIToFPOp, LLVM::SIToFPOp>;
134 using SubFOpLowering =
135     VectorConvertToLLVMPattern<arith::SubFOp, LLVM::FSubOp,
136                                arith::AttrConvertFastMathToLLVM>;
137 using SubIOpLowering =
138     VectorConvertToLLVMPattern<arith::SubIOp, LLVM::SubOp,
139                                arith::AttrConvertOverflowToLLVM>;
140 using TruncFOpLowering =
141     ConstrainedVectorConvertToLLVMPattern<arith::TruncFOp, LLVM::FPTruncOp,
142                                           false>;
143 using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
144     arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
145     arith::AttrConverterConstrainedFPToLLVM>;
146 using TruncIOpLowering =
147     VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
148 using UIToFPOpLowering =
149     VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
150 using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
151 
152 //===----------------------------------------------------------------------===//
153 // Op Lowering Patterns
154 //===----------------------------------------------------------------------===//
155 
156 /// Directly lower to LLVM op.
157 struct ConstantOpLowering : public ConvertOpToLLVMPattern<arith::ConstantOp> {
158   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
159 
160   LogicalResult
161   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
162                   ConversionPatternRewriter &rewriter) const override;
163 };
164 
165 /// The lowering of index_cast becomes an integer conversion since index
166 /// becomes an integer.  If the bit width of the source and target integer
167 /// types is the same, just erase the cast.  If the target type is wider,
168 /// sign-extend the value, otherwise truncate it.
169 template <typename OpTy, typename ExtCastTy>
170 struct IndexCastOpLowering : public ConvertOpToLLVMPattern<OpTy> {
171   using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
172 
173   LogicalResult
174   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
175                   ConversionPatternRewriter &rewriter) const override;
176 };
177 
178 using IndexCastOpSILowering =
179     IndexCastOpLowering<arith::IndexCastOp, LLVM::SExtOp>;
180 using IndexCastOpUILowering =
181     IndexCastOpLowering<arith::IndexCastUIOp, LLVM::ZExtOp>;
182 
183 struct AddUIExtendedOpLowering
184     : public ConvertOpToLLVMPattern<arith::AddUIExtendedOp> {
185   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
186 
187   LogicalResult
188   matchAndRewrite(arith::AddUIExtendedOp op, OpAdaptor adaptor,
189                   ConversionPatternRewriter &rewriter) const override;
190 };
191 
192 template <typename ArithMulOp, bool IsSigned>
193 struct MulIExtendedOpLowering : public ConvertOpToLLVMPattern<ArithMulOp> {
194   using ConvertOpToLLVMPattern<ArithMulOp>::ConvertOpToLLVMPattern;
195 
196   LogicalResult
197   matchAndRewrite(ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
198                   ConversionPatternRewriter &rewriter) const override;
199 };
200 
201 using MulSIExtendedOpLowering =
202     MulIExtendedOpLowering<arith::MulSIExtendedOp, true>;
203 using MulUIExtendedOpLowering =
204     MulIExtendedOpLowering<arith::MulUIExtendedOp, false>;
205 
206 struct CmpIOpLowering : public ConvertOpToLLVMPattern<arith::CmpIOp> {
207   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
208 
209   LogicalResult
210   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
211                   ConversionPatternRewriter &rewriter) const override;
212 };
213 
214 struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
215   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
216 
217   LogicalResult
218   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
219                   ConversionPatternRewriter &rewriter) const override;
220 };
221 
222 } // namespace
223 
224 //===----------------------------------------------------------------------===//
225 // ConstantOpLowering
226 //===----------------------------------------------------------------------===//
227 
228 LogicalResult
229 ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
230                                     ConversionPatternRewriter &rewriter) const {
231   return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(),
232                                        adaptor.getOperands(), op->getAttrs(),
233                                        *getTypeConverter(), rewriter);
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // IndexCastOpLowering
238 //===----------------------------------------------------------------------===//
239 
240 template <typename OpTy, typename ExtCastTy>
241 LogicalResult IndexCastOpLowering<OpTy, ExtCastTy>::matchAndRewrite(
242     OpTy op, typename OpTy::Adaptor adaptor,
243     ConversionPatternRewriter &rewriter) const {
244   Type resultType = op.getResult().getType();
245   Type targetElementType =
246       this->typeConverter->convertType(getElementTypeOrSelf(resultType));
247   Type sourceElementType =
248       this->typeConverter->convertType(getElementTypeOrSelf(op.getIn()));
249   unsigned targetBits = targetElementType.getIntOrFloatBitWidth();
250   unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth();
251 
252   if (targetBits == sourceBits) {
253     rewriter.replaceOp(op, adaptor.getIn());
254     return success();
255   }
256 
257   // Handle the scalar and 1D vector cases.
258   Type operandType = adaptor.getIn().getType();
259   if (!isa<LLVM::LLVMArrayType>(operandType)) {
260     Type targetType = this->typeConverter->convertType(resultType);
261     if (targetBits < sourceBits)
262       rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, targetType,
263                                                  adaptor.getIn());
264     else
265       rewriter.replaceOpWithNewOp<ExtCastTy>(op, targetType, adaptor.getIn());
266     return success();
267   }
268 
269   if (!isa<VectorType>(resultType))
270     return rewriter.notifyMatchFailure(op, "expected vector result type");
271 
272   return LLVM::detail::handleMultidimensionalVectors(
273       op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
274       [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
275         typename OpTy::Adaptor adaptor(operands);
276         if (targetBits < sourceBits) {
277           return rewriter.create<LLVM::TruncOp>(op.getLoc(), llvm1DVectorTy,
278                                                 adaptor.getIn());
279         }
280         return rewriter.create<ExtCastTy>(op.getLoc(), llvm1DVectorTy,
281                                           adaptor.getIn());
282       },
283       rewriter);
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // AddUIExtendedOpLowering
288 //===----------------------------------------------------------------------===//
289 
290 LogicalResult AddUIExtendedOpLowering::matchAndRewrite(
291     arith::AddUIExtendedOp op, OpAdaptor adaptor,
292     ConversionPatternRewriter &rewriter) const {
293   Type operandType = adaptor.getLhs().getType();
294   Type sumResultType = op.getSum().getType();
295   Type overflowResultType = op.getOverflow().getType();
296 
297   if (!LLVM::isCompatibleType(operandType))
298     return failure();
299 
300   MLIRContext *ctx = rewriter.getContext();
301   Location loc = op.getLoc();
302 
303   // Handle the scalar and 1D vector cases.
304   if (!isa<LLVM::LLVMArrayType>(operandType)) {
305     Type newOverflowType = typeConverter->convertType(overflowResultType);
306     Type structType =
307         LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType});
308     Value addOverflow = rewriter.create<LLVM::UAddWithOverflowOp>(
309         loc, structType, adaptor.getLhs(), adaptor.getRhs());
310     Value sumExtracted =
311         rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 0);
312     Value overflowExtracted =
313         rewriter.create<LLVM::ExtractValueOp>(loc, addOverflow, 1);
314     rewriter.replaceOp(op, {sumExtracted, overflowExtracted});
315     return success();
316   }
317 
318   if (!isa<VectorType>(sumResultType))
319     return rewriter.notifyMatchFailure(loc, "expected vector result types");
320 
321   return rewriter.notifyMatchFailure(loc,
322                                      "ND vector types are not supported yet");
323 }
324 
325 //===----------------------------------------------------------------------===//
326 // MulIExtendedOpLowering
327 //===----------------------------------------------------------------------===//
328 
329 template <typename ArithMulOp, bool IsSigned>
330 LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
331     ArithMulOp op, typename ArithMulOp::Adaptor adaptor,
332     ConversionPatternRewriter &rewriter) const {
333   Type resultType = adaptor.getLhs().getType();
334 
335   if (!LLVM::isCompatibleType(resultType))
336     return failure();
337 
338   Location loc = op.getLoc();
339 
340   // Handle the scalar and 1D vector cases. Because LLVM does not have a
341   // matching extended multiplication intrinsic, perform regular multiplication
342   // on operands zero-extended to i(2*N) bits, and truncate the results back to
343   // iN types.
344   if (!isa<LLVM::LLVMArrayType>(resultType)) {
345     // Shift amount necessary to extract the high bits from widened result.
346     TypedAttr shiftValAttr;
347 
348     if (auto intTy = dyn_cast<IntegerType>(resultType)) {
349       unsigned resultBitwidth = intTy.getWidth();
350       auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
351       shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
352     } else {
353       auto vecTy = cast<VectorType>(resultType);
354       unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
355       auto attrTy = VectorType::get(
356           vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
357       shiftValAttr = SplatElementsAttr::get(
358           attrTy, APInt(resultBitwidth * 2, resultBitwidth));
359     }
360     Type wideType = shiftValAttr.getType();
361     assert(LLVM::isCompatibleType(wideType) &&
362            "LLVM dialect should support all signless integer types");
363 
364     using LLVMExtOp = std::conditional_t<IsSigned, LLVM::SExtOp, LLVM::ZExtOp>;
365     Value lhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getLhs());
366     Value rhsExt = rewriter.create<LLVMExtOp>(loc, wideType, adaptor.getRhs());
367     Value mulExt = rewriter.create<LLVM::MulOp>(loc, wideType, lhsExt, rhsExt);
368 
369     // Split the 2*N-bit wide result into two N-bit values.
370     Value low = rewriter.create<LLVM::TruncOp>(loc, resultType, mulExt);
371     Value shiftVal = rewriter.create<LLVM::ConstantOp>(loc, shiftValAttr);
372     Value highExt = rewriter.create<LLVM::LShrOp>(loc, mulExt, shiftVal);
373     Value high = rewriter.create<LLVM::TruncOp>(loc, resultType, highExt);
374 
375     rewriter.replaceOp(op, {low, high});
376     return success();
377   }
378 
379   if (!isa<VectorType>(resultType))
380     return rewriter.notifyMatchFailure(op, "expected vector result type");
381 
382   return rewriter.notifyMatchFailure(op,
383                                      "ND vector types are not supported yet");
384 }
385 
//===----------------------------------------------------------------------===//
// CmpIOpLowering
//===----------------------------------------------------------------------===//

// Map an arith comparison predicate onto the corresponding LLVM dialect
// predicate. Both enums assign the same numeric value to each predicate, so
// the conversion reinterprets the underlying integer.
template <typename LLVMPredType, typename PredType>
static LLVMPredType convertCmpPredicate(PredType pred) {
  using Underlying = std::underlying_type_t<PredType>;
  const Underlying rawValue = static_cast<Underlying>(pred);
  return static_cast<LLVMPredType>(rawValue);
}
396 
397 LogicalResult
398 CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
399                                 ConversionPatternRewriter &rewriter) const {
400   Type operandType = adaptor.getLhs().getType();
401   Type resultType = op.getResult().getType();
402 
403   // Handle the scalar and 1D vector cases.
404   if (!isa<LLVM::LLVMArrayType>(operandType)) {
405     rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
406         op, typeConverter->convertType(resultType),
407         convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
408         adaptor.getLhs(), adaptor.getRhs());
409     return success();
410   }
411 
412   if (!isa<VectorType>(resultType))
413     return rewriter.notifyMatchFailure(op, "expected vector result type");
414 
415   return LLVM::detail::handleMultidimensionalVectors(
416       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
417       [&](Type llvm1DVectorTy, ValueRange operands) {
418         OpAdaptor adaptor(operands);
419         return rewriter.create<LLVM::ICmpOp>(
420             op.getLoc(), llvm1DVectorTy,
421             convertCmpPredicate<LLVM::ICmpPredicate>(op.getPredicate()),
422             adaptor.getLhs(), adaptor.getRhs());
423       },
424       rewriter);
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // CmpFOpLowering
429 //===----------------------------------------------------------------------===//
430 
431 LogicalResult
432 CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
433                                 ConversionPatternRewriter &rewriter) const {
434   Type operandType = adaptor.getLhs().getType();
435   Type resultType = op.getResult().getType();
436   LLVM::FastmathFlags fmf =
437       arith::convertArithFastMathFlagsToLLVM(op.getFastmath());
438 
439   // Handle the scalar and 1D vector cases.
440   if (!isa<LLVM::LLVMArrayType>(operandType)) {
441     rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
442         op, typeConverter->convertType(resultType),
443         convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
444         adaptor.getLhs(), adaptor.getRhs(), fmf);
445     return success();
446   }
447 
448   if (!isa<VectorType>(resultType))
449     return rewriter.notifyMatchFailure(op, "expected vector result type");
450 
451   return LLVM::detail::handleMultidimensionalVectors(
452       op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
453       [&](Type llvm1DVectorTy, ValueRange operands) {
454         OpAdaptor adaptor(operands);
455         return rewriter.create<LLVM::FCmpOp>(
456             op.getLoc(), llvm1DVectorTy,
457             convertCmpPredicate<LLVM::FCmpPredicate>(op.getPredicate()),
458             adaptor.getLhs(), adaptor.getRhs(), fmf);
459       },
460       rewriter);
461 }
462 
463 //===----------------------------------------------------------------------===//
464 // Pass Definition
465 //===----------------------------------------------------------------------===//
466 
467 namespace {
468 struct ArithToLLVMConversionPass
469     : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
470   using Base::Base;
471 
472   void runOnOperation() override {
473     LLVMConversionTarget target(getContext());
474     RewritePatternSet patterns(&getContext());
475 
476     LowerToLLVMOptions options(&getContext());
477     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
478       options.overrideIndexBitwidth(indexBitwidth);
479 
480     LLVMTypeConverter converter(&getContext(), options);
481     arith::populateCeilFloorDivExpandOpsPatterns(patterns);
482     arith::populateArithToLLVMConversionPatterns(converter, patterns);
483 
484     if (failed(applyPartialConversion(getOperation(), target,
485                                       std::move(patterns))))
486       signalPassFailure();
487   }
488 };
489 } // namespace
490 
491 //===----------------------------------------------------------------------===//
492 // ConvertToLLVMPatternInterface implementation
493 //===----------------------------------------------------------------------===//
494 
495 namespace {
496 /// Implement the interface to convert MemRef to LLVM.
497 struct ArithToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
498   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
499   void loadDependentDialects(MLIRContext *context) const final {
500     context->loadDialect<LLVM::LLVMDialect>();
501   }
502 
503   /// Hook for derived dialect interface to provide conversion patterns
504   /// and mark dialect legal for the conversion target.
505   void populateConvertToLLVMConversionPatterns(
506       ConversionTarget &target, LLVMTypeConverter &typeConverter,
507       RewritePatternSet &patterns) const final {
508     arith::populateCeilFloorDivExpandOpsPatterns(patterns);
509     arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
510   }
511 };
512 } // namespace
513 
514 void mlir::arith::registerConvertArithToLLVMInterface(
515     DialectRegistry &registry) {
516   registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
517     dialect->addInterfaces<ArithToLLVMDialectInterface>();
518   });
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // Pattern Population
523 //===----------------------------------------------------------------------===//
524 
525 void mlir::arith::populateArithToLLVMConversionPatterns(
526     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
527   // clang-format off
528   patterns.add<
529     AddFOpLowering,
530     AddIOpLowering,
531     AndIOpLowering,
532     AddUIExtendedOpLowering,
533     BitcastOpLowering,
534     ConstantOpLowering,
535     CmpFOpLowering,
536     CmpIOpLowering,
537     DivFOpLowering,
538     DivSIOpLowering,
539     DivUIOpLowering,
540     ExtFOpLowering,
541     ExtSIOpLowering,
542     ExtUIOpLowering,
543     FPToSIOpLowering,
544     FPToUIOpLowering,
545     IndexCastOpSILowering,
546     IndexCastOpUILowering,
547     MaximumFOpLowering,
548     MaxNumFOpLowering,
549     MaxSIOpLowering,
550     MaxUIOpLowering,
551     MinimumFOpLowering,
552     MinNumFOpLowering,
553     MinSIOpLowering,
554     MinUIOpLowering,
555     MulFOpLowering,
556     MulIOpLowering,
557     MulSIExtendedOpLowering,
558     MulUIExtendedOpLowering,
559     NegFOpLowering,
560     OrIOpLowering,
561     RemFOpLowering,
562     RemSIOpLowering,
563     RemUIOpLowering,
564     SelectOpLowering,
565     ShLIOpLowering,
566     ShRSIOpLowering,
567     ShRUIOpLowering,
568     SIToFPOpLowering,
569     SubFOpLowering,
570     SubIOpLowering,
571     TruncFOpLowering,
572     ConstrainedTruncFOpLowering,
573     TruncIOpLowering,
574     UIToFPOpLowering,
575     XOrIOpLowering
576   >(converter);
577   // clang-format on
578 }
579