1 //===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of tanh op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/Math/Transforms/Passes.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 /// Create a float constant. 26 static Value createFloatConst(Location loc, Type type, double value, 27 OpBuilder &b) { 28 auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value); 29 if (auto shapedTy = dyn_cast<ShapedType>(type)) { 30 return b.create<arith::ConstantOp>(loc, 31 DenseElementsAttr::get(shapedTy, attr)); 32 } 33 34 return b.create<arith::ConstantOp>(loc, attr); 35 } 36 37 /// Create a float constant. 38 static Value createIntConst(Location loc, Type type, int64_t value, 39 OpBuilder &b) { 40 auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); 41 if (auto shapedTy = dyn_cast<ShapedType>(type)) { 42 return b.create<arith::ConstantOp>(loc, 43 DenseElementsAttr::get(shapedTy, attr)); 44 } 45 46 return b.create<arith::ConstantOp>(loc, attr); 47 } 48 49 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { 50 Type opType = operand.getType(); 51 Type i64Ty = b.getI64Type(); 52 if (auto shapedTy = dyn_cast<ShapedType>(opType)) 53 i64Ty = shapedTy.clone(i64Ty); 54 Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand); 55 Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert); 56 // The truncation does not preserve the sign when the truncated 57 // value is -0. So here the sign is copied again. 58 return b.create<math::CopySignOp>(fpFixedConvert, operand); 59 } 60 61 /// Expands tanh op into 62 /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 63 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0 64 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { 65 auto floatType = op.getOperand().getType(); 66 Location loc = op.getLoc(); 67 Value one = createFloatConst(loc, floatType, 1.0, rewriter); 68 Value two = createFloatConst(loc, floatType, 2.0, rewriter); 69 Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two); 70 71 // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x} 72 Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX); 73 Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX); 74 Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x); 75 Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x); 76 Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 77 78 // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1 79 exp2x = rewriter.create<math::ExpOp>(loc, doubledX); 80 dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one); 81 divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one); 82 Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 83 84 // tanh(x) = x >= 0 ? positiveRes : negativeRes 85 Value zero = createFloatConst(loc, floatType, 0.0, rewriter); 86 Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, 87 op.getOperand(), zero); 88 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes, 89 negativeRes); 90 return success(); 91 } 92 93 // Converts math.tan to math.sin, math.cos, and arith.divf. 94 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { 95 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 96 Value operand = op.getOperand(); 97 Type type = operand.getType(); 98 Value sin = b.create<math::SinOp>(type, operand); 99 Value cos = b.create<math::CosOp>(type, operand); 100 Value div = b.create<arith::DivFOp>(type, sin, cos); 101 rewriter.replaceOp(op, div); 102 return success(); 103 } 104 105 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { 106 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 107 Value operandA = op.getOperand(0); 108 Value operandB = op.getOperand(1); 109 Value operandC = op.getOperand(2); 110 Type type = op.getType(); 111 Value mult = b.create<arith::MulFOp>(type, operandA, operandB); 112 Value add = b.create<arith::AddFOp>(type, mult, operandC); 113 rewriter.replaceOp(op, add); 114 return success(); 115 } 116 117 // Converts a floorf() function to the following: 118 // floorf(float x) -> 119 // y = (float)(int) x 120 // if (x < 0) then incr = -1 else incr = 0 121 // y = y + incr <= replace this op with the floorf op. 122 static LogicalResult convertFloorOp(math::FloorOp op, 123 PatternRewriter &rewriter) { 124 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 125 Value operand = op.getOperand(); 126 Type opType = operand.getType(); 127 Value fpFixedConvert = createTruncatedFPValue(operand, b); 128 129 // Creating constants for later use. 130 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 131 Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); 132 133 Value negCheck = 134 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero); 135 Value incrValue = 136 b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero); 137 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); 138 rewriter.replaceOp(op, ret); 139 return success(); 140 } 141 142 // Converts a ceilf() function to the following: 143 // ceilf(float x) -> 144 // y = (float)(int) x 145 // if (x > y) then incr = 1 else incr = 0 146 // y = y + incr <= replace this op with the ceilf op. 147 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { 148 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 149 Value operand = op.getOperand(); 150 Type opType = operand.getType(); 151 Value fpFixedConvert = createTruncatedFPValue(operand, b); 152 153 // Creating constants for later use. 154 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 155 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); 156 157 Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, 158 fpFixedConvert); 159 Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero); 160 161 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); 162 rewriter.replaceOp(op, ret); 163 return success(); 164 } 165 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) 166 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) { 167 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 168 Value operandA = op.getOperand(0); 169 Value operandB = op.getOperand(1); 170 Type opType = operandA.getType(); 171 172 Value logA = b.create<math::LogOp>(opType, operandA); 173 Value mult = b.create<arith::MulFOp>(opType, logA, operandB); 174 Value expResult = b.create<math::ExpOp>(opType, mult); 175 rewriter.replaceOp(op, expResult); 176 return success(); 177 } 178 179 // exp2f(float x) -> exp(x * ln(2)) 180 // Proof: Let's say 2^x = y 181 // ln(2^x) = ln(y) 182 // x * ln(2) = ln(y) => e ^(x*ln(2)) = y 183 static LogicalResult convertExp2fOp(math::Exp2Op op, 184 PatternRewriter &rewriter) { 185 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 186 Value operand = op.getOperand(); 187 Type opType = operand.getType(); 188 Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b); 189 Value mult = b.create<arith::MulFOp>(opType, operand, ln2); 190 Value exp = b.create<math::ExpOp>(op->getLoc(), mult); 191 rewriter.replaceOp(op, exp); 192 return success(); 193 } 194 195 static LogicalResult convertRoundOp(math::RoundOp op, 196 PatternRewriter &rewriter) { 197 Location loc = op.getLoc(); 198 ImplicitLocOpBuilder b(loc, rewriter); 199 Value operand = op.getOperand(); 200 Type opType = operand.getType(); 201 Type opEType = getElementTypeOrSelf(opType); 202 203 if (!opEType.isF32()) { 204 return rewriter.notifyMatchFailure(op, "not a round of f32."); 205 } 206 207 Type i32Ty = b.getI32Type(); 208 if (auto shapedTy = dyn_cast<ShapedType>(opType)) 209 i32Ty = shapedTy.clone(i32Ty); 210 211 Value half = createFloatConst(loc, opType, 0.5, b); 212 Value c23 = createIntConst(loc, i32Ty, 23, b); 213 Value c127 = createIntConst(loc, i32Ty, 127, b); 214 Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b); 215 216 Value incrValue = b.create<math::CopySignOp>(half, operand); 217 Value add = b.create<arith::AddFOp>(opType, operand, incrValue); 218 Value fpFixedConvert = createTruncatedFPValue(add, b); 219 220 // There are three cases where adding 0.5 to the value and truncating by 221 // converting to an i64 does not result in the correct behavior: 222 // 223 // 1. Special values: +-inf and +-nan 224 // Casting these special values to i64 has undefined behavior. To identify 225 // these values, we use the fact that these values are the only float 226 // values with the maximum possible biased exponent. 227 // 228 // 2. Large values: 2^23 <= |x| <= INT_64_MAX 229 // Adding 0.5 to a float larger than or equal to 2^23 results in precision 230 // errors that sometimes round the value up and sometimes round the value 231 // down. For example: 232 // 8388608.0 + 0.5 = 8388608.0 233 // 8388609.0 + 0.5 = 8388610.0 234 // 235 // 3. Very large values: |x| > INT_64_MAX 236 // Casting to i64 a value greater than the max i64 value will overflow the 237 // i64 leading to wrong outputs. 238 // 239 // All three cases satisfy the property `biasedExp >= 23`. 240 Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand); 241 Value operandExp = b.create<arith::AndIOp>( 242 b.create<arith::ShRUIOp>(operandBitcast, c23), expMask); 243 Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127); 244 Value isSpecialValOrLargeVal = 245 b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23); 246 247 Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand, 248 fpFixedConvert); 249 rewriter.replaceOp(op, result); 250 return success(); 251 } 252 253 // Converts math.ctlz to scf and arith operations. This is done 254 // by performing a binary search on the bits. 255 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, 256 PatternRewriter &rewriter) { 257 auto operand = op.getOperand(); 258 auto operandTy = operand.getType(); 259 auto eTy = getElementTypeOrSelf(operandTy); 260 Location loc = op.getLoc(); 261 262 int32_t bitwidth = eTy.getIntOrFloatBitWidth(); 263 if (bitwidth > 64) 264 return failure(); 265 266 uint64_t allbits = -1; 267 if (bitwidth < 64) { 268 allbits = allbits >> (64 - bitwidth); 269 } 270 271 Value x = operand; 272 Value count = createIntConst(loc, operandTy, 0, rewriter); 273 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) { 274 auto half = bw / 2; 275 auto bits = createIntConst(loc, operandTy, half, rewriter); 276 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); 277 278 Value pred = 279 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask); 280 Value add = rewriter.create<arith::AddIOp>(loc, count, bits); 281 Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits); 282 283 x = rewriter.create<arith::SelectOp>(loc, pred, shift, x); 284 count = rewriter.create<arith::SelectOp>(loc, pred, add, count); 285 } 286 287 Value zero = createIntConst(loc, operandTy, 0, rewriter); 288 Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 289 operand, zero); 290 291 Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); 292 Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count); 293 rewriter.replaceOp(op, sel); 294 return success(); 295 } 296 297 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { 298 patterns.add(convertCtlzOp); 299 } 300 301 void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { 302 patterns.add(convertTanOp); 303 } 304 305 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { 306 patterns.add(convertTanhOp); 307 } 308 309 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { 310 patterns.add(convertFmaFOp); 311 } 312 313 void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { 314 patterns.add(convertCeilOp); 315 } 316 317 void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { 318 patterns.add(convertExp2fOp); 319 } 320 321 void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { 322 patterns.add(convertPowfOp); 323 } 324 325 void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { 326 patterns.add(convertRoundOp); 327 } 328 329 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) { 330 patterns.add(convertFloorOp); 331 } 332