xref: /llvm-project/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp (revision e553ac4d8148291914526f4f66f09e362ce0a63f)
//===- AttrToLLVMConverter.cpp - Arith attributes conversion to LLVM ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
10 
11 using namespace mlir;
12 
13 LLVM::FastmathFlags
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)14 mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
15   LLVM::FastmathFlags llvmFMF{};
16   const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
17       {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
18       {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
19       {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
20       {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
21       {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
22       {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
23       {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
24   for (auto [arithFlag, llvmFlag] : flags) {
25     if (bitEnumContainsAny(arithFMF, arithFlag))
26       llvmFMF = llvmFMF | llvmFlag;
27   }
28   return llvmFMF;
29 }
30 
31 LLVM::FastmathFlagsAttr
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr)32 mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
33   arith::FastMathFlags arithFMF = fmfAttr.getValue();
34   return LLVM::FastmathFlagsAttr::get(
35       fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
36 }
37 
convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags)38 LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
39     arith::IntegerOverflowFlags arithFlags) {
40   LLVM::IntegerOverflowFlags llvmFlags{};
41   const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
42       flags[] = {
43           {arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
44           {arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
45   for (auto [arithFlag, llvmFlag] : flags) {
46     if (bitEnumContainsAny(arithFlags, arithFlag))
47       llvmFlags = llvmFlags | llvmFlag;
48   }
49   return llvmFlags;
50 }
51 
52 LLVM::RoundingMode
convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode)53 mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
54   switch (roundingMode) {
55   case arith::RoundingMode::downward:
56     return LLVM::RoundingMode::TowardNegative;
57   case arith::RoundingMode::to_nearest_away:
58     return LLVM::RoundingMode::NearestTiesToAway;
59   case arith::RoundingMode::to_nearest_even:
60     return LLVM::RoundingMode::NearestTiesToEven;
61   case arith::RoundingMode::toward_zero:
62     return LLVM::RoundingMode::TowardZero;
63   case arith::RoundingMode::upward:
64     return LLVM::RoundingMode::TowardPositive;
65   }
66   llvm_unreachable("Unhandled rounding mode");
67 }
68 
convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr)69 LLVM::RoundingModeAttr mlir::arith::convertArithRoundingModeAttrToLLVM(
70     arith::RoundingModeAttr roundingModeAttr) {
71   assert(roundingModeAttr && "Expecting valid attribute");
72   return LLVM::RoundingModeAttr::get(
73       roundingModeAttr.getContext(),
74       convertArithRoundingModeToLLVM(roundingModeAttr.getValue()));
75 }
76 
77 LLVM::FPExceptionBehaviorAttr
getLLVMDefaultFPExceptionBehavior(MLIRContext & context)78 mlir::arith::getLLVMDefaultFPExceptionBehavior(MLIRContext &context) {
79   return LLVM::FPExceptionBehaviorAttr::get(&context,
80                                             LLVM::FPExceptionBehavior::Ignore);
81 }
82