1 //===- AttrToLLVMConverter.h - Arith attributes conversion ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H 10 #define MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H 11 12 #include "mlir/Dialect/Arith/IR/Arith.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 15 //===----------------------------------------------------------------------===// 16 // Support for converting Arith FastMathFlags to LLVM FastmathFlags 17 //===----------------------------------------------------------------------===// 18 19 namespace mlir { 20 namespace arith { 21 /// Maps arithmetic fastmath enum values to LLVM enum values. 22 LLVM::FastmathFlags 23 convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF); 24 25 /// Creates an LLVM fastmath attribute from a given arithmetic fastmath 26 /// attribute. 27 LLVM::FastmathFlagsAttr 28 convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr); 29 30 /// Maps arithmetic overflow enum values to LLVM enum values. 31 LLVM::IntegerOverflowFlags 32 convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags); 33 34 /// Creates an LLVM rounding mode enum value from a given arithmetic rounding 35 /// mode enum value. 36 LLVM::RoundingMode 37 convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode); 38 39 /// Creates an LLVM rounding mode attribute from a given arithmetic rounding 40 /// mode attribute. 41 LLVM::RoundingModeAttr 42 convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr); 43 44 /// Returns an attribute for the default LLVM FP exception behavior. 45 LLVM::FPExceptionBehaviorAttr 46 getLLVMDefaultFPExceptionBehavior(MLIRContext &context); 47 48 // Attribute converter that populates a NamedAttrList by removing the fastmath 49 // attribute from the source operation attributes, and replacing it with an 50 // equivalent LLVM fastmath attribute. 51 template <typename SourceOp, typename TargetOp> 52 class AttrConvertFastMathToLLVM { 53 public: AttrConvertFastMathToLLVM(SourceOp srcOp)54 AttrConvertFastMathToLLVM(SourceOp srcOp) { 55 // Copy the source attributes. 56 convertedAttr = NamedAttrList{srcOp->getAttrs()}; 57 // Get the name of the arith fastmath attribute. 58 StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); 59 // Remove the source fastmath attribute. 60 auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>( 61 convertedAttr.erase(arithFMFAttrName)); 62 if (arithFMFAttr) { 63 StringRef targetAttrName = TargetOp::getFastmathAttrName(); 64 convertedAttr.set(targetAttrName, 65 convertArithFastMathAttrToLLVM(arithFMFAttr)); 66 } 67 } 68 getAttrs()69 ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } getOverflowFlags()70 LLVM::IntegerOverflowFlags getOverflowFlags() const { 71 return LLVM::IntegerOverflowFlags::none; 72 } 73 74 private: 75 NamedAttrList convertedAttr; 76 }; 77 78 // Attribute converter that populates a NamedAttrList by removing the overflow 79 // attribute from the source operation attributes, and replacing it with an 80 // equivalent LLVM overflow attribute. 81 template <typename SourceOp, typename TargetOp> 82 class AttrConvertOverflowToLLVM { 83 public: AttrConvertOverflowToLLVM(SourceOp srcOp)84 AttrConvertOverflowToLLVM(SourceOp srcOp) { 85 // Copy the source attributes. 86 convertedAttr = NamedAttrList{srcOp->getAttrs()}; 87 // Get the name of the arith overflow attribute. 88 StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); 89 // Remove the source overflow attribute. 90 if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>( 91 convertedAttr.erase(arithAttrName))) { 92 overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue()); 93 } 94 } 95 getAttrs()96 ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } getOverflowFlags()97 LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; } 98 99 private: 100 NamedAttrList convertedAttr; 101 LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none; 102 }; 103 104 template <typename SourceOp, typename TargetOp> 105 class AttrConverterConstrainedFPToLLVM { 106 static_assert(TargetOp::template hasTrait< 107 LLVM::FPExceptionBehaviorOpInterface::Trait>(), 108 "Target constrained FP operations must implement " 109 "LLVM::FPExceptionBehaviorOpInterface"); 110 111 public: AttrConverterConstrainedFPToLLVM(SourceOp srcOp)112 AttrConverterConstrainedFPToLLVM(SourceOp srcOp) { 113 // Copy the source attributes. 114 convertedAttr = NamedAttrList{srcOp->getAttrs()}; 115 116 if constexpr (TargetOp::template hasTrait< 117 LLVM::RoundingModeOpInterface::Trait>()) { 118 // Get the name of the rounding mode attribute. 119 StringRef arithAttrName = srcOp.getRoundingModeAttrName(); 120 // Remove the source attribute. 121 auto arithAttr = 122 cast<arith::RoundingModeAttr>(convertedAttr.erase(arithAttrName)); 123 // Set the target attribute. 124 convertedAttr.set(TargetOp::getRoundingModeAttrName(), 125 convertArithRoundingModeAttrToLLVM(arithAttr)); 126 } 127 convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(), 128 getLLVMDefaultFPExceptionBehavior(*srcOp->getContext())); 129 } 130 getAttrs()131 ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } getOverflowFlags()132 LLVM::IntegerOverflowFlags getOverflowFlags() const { 133 return LLVM::IntegerOverflowFlags::none; 134 } 135 136 private: 137 NamedAttrList convertedAttr; 138 }; 139 140 } // namespace arith 141 } // namespace mlir 142 143 #endif // MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H 144