xref: /llvm-project/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h (revision e553ac4d8148291914526f4f66f09e362ce0a63f)
1 //===- AttrToLLVMConverter.h - Arith attributes conversion ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
10 #define MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
11 
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 
15 //===----------------------------------------------------------------------===//
// Support for converting Arith attributes to their LLVM counterparts
17 //===----------------------------------------------------------------------===//
18 
19 namespace mlir {
20 namespace arith {
21 /// Maps arithmetic fastmath enum values to LLVM enum values.
22 LLVM::FastmathFlags
23 convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
24 
25 /// Creates an LLVM fastmath attribute from a given arithmetic fastmath
26 /// attribute.
27 LLVM::FastmathFlagsAttr
28 convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
29 
30 /// Maps arithmetic overflow enum values to LLVM enum values.
31 LLVM::IntegerOverflowFlags
32 convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
33 
34 /// Creates an LLVM rounding mode enum value from a given arithmetic rounding
35 /// mode enum value.
36 LLVM::RoundingMode
37 convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode);
38 
39 /// Creates an LLVM rounding mode attribute from a given arithmetic rounding
40 /// mode attribute.
41 LLVM::RoundingModeAttr
42 convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr);
43 
44 /// Returns an attribute for the default LLVM FP exception behavior.
45 LLVM::FPExceptionBehaviorAttr
46 getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
47 
48 // Attribute converter that populates a NamedAttrList by removing the fastmath
49 // attribute from the source operation attributes, and replacing it with an
50 // equivalent LLVM fastmath attribute.
51 template <typename SourceOp, typename TargetOp>
52 class AttrConvertFastMathToLLVM {
53 public:
AttrConvertFastMathToLLVM(SourceOp srcOp)54   AttrConvertFastMathToLLVM(SourceOp srcOp) {
55     // Copy the source attributes.
56     convertedAttr = NamedAttrList{srcOp->getAttrs()};
57     // Get the name of the arith fastmath attribute.
58     StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
59     // Remove the source fastmath attribute.
60     auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
61         convertedAttr.erase(arithFMFAttrName));
62     if (arithFMFAttr) {
63       StringRef targetAttrName = TargetOp::getFastmathAttrName();
64       convertedAttr.set(targetAttrName,
65                         convertArithFastMathAttrToLLVM(arithFMFAttr));
66     }
67   }
68 
getAttrs()69   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
getOverflowFlags()70   LLVM::IntegerOverflowFlags getOverflowFlags() const {
71     return LLVM::IntegerOverflowFlags::none;
72   }
73 
74 private:
75   NamedAttrList convertedAttr;
76 };
77 
78 // Attribute converter that populates a NamedAttrList by removing the overflow
79 // attribute from the source operation attributes, and replacing it with an
80 // equivalent LLVM overflow attribute.
81 template <typename SourceOp, typename TargetOp>
82 class AttrConvertOverflowToLLVM {
83 public:
AttrConvertOverflowToLLVM(SourceOp srcOp)84   AttrConvertOverflowToLLVM(SourceOp srcOp) {
85     // Copy the source attributes.
86     convertedAttr = NamedAttrList{srcOp->getAttrs()};
87     // Get the name of the arith overflow attribute.
88     StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
89     // Remove the source overflow attribute.
90     if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
91             convertedAttr.erase(arithAttrName))) {
92       overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
93     }
94   }
95 
getAttrs()96   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
getOverflowFlags()97   LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
98 
99 private:
100   NamedAttrList convertedAttr;
101   LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
102 };
103 
104 template <typename SourceOp, typename TargetOp>
105 class AttrConverterConstrainedFPToLLVM {
106   static_assert(TargetOp::template hasTrait<
107                     LLVM::FPExceptionBehaviorOpInterface::Trait>(),
108                 "Target constrained FP operations must implement "
109                 "LLVM::FPExceptionBehaviorOpInterface");
110 
111 public:
AttrConverterConstrainedFPToLLVM(SourceOp srcOp)112   AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
113     // Copy the source attributes.
114     convertedAttr = NamedAttrList{srcOp->getAttrs()};
115 
116     if constexpr (TargetOp::template hasTrait<
117                       LLVM::RoundingModeOpInterface::Trait>()) {
118       // Get the name of the rounding mode attribute.
119       StringRef arithAttrName = srcOp.getRoundingModeAttrName();
120       // Remove the source attribute.
121       auto arithAttr =
122           cast<arith::RoundingModeAttr>(convertedAttr.erase(arithAttrName));
123       // Set the target attribute.
124       convertedAttr.set(TargetOp::getRoundingModeAttrName(),
125                         convertArithRoundingModeAttrToLLVM(arithAttr));
126     }
127     convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
128                       getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
129   }
130 
getAttrs()131   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
getOverflowFlags()132   LLVM::IntegerOverflowFlags getOverflowFlags() const {
133     return LLVM::IntegerOverflowFlags::none;
134   }
135 
136 private:
137   NamedAttrList convertedAttr;
138 };
139 
140 } // namespace arith
141 } // namespace mlir
142 
143 #endif // MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
144