1 //===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of tanh op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/Math/Transforms/Passes.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/Dialect/Vector/IR/VectorOps.h" 18 #include "mlir/IR/Builders.h" 19 #include "mlir/IR/ImplicitLocOpBuilder.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 23 using namespace mlir; 24 25 /// Create a float constant. 26 static Value createFloatConst(Location loc, Type type, double value, 27 OpBuilder &b) { 28 auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value); 29 if (auto shapedTy = dyn_cast<ShapedType>(type)) { 30 return b.create<arith::ConstantOp>(loc, 31 DenseElementsAttr::get(shapedTy, attr)); 32 } 33 34 return b.create<arith::ConstantOp>(loc, attr); 35 } 36 37 /// Create a float constant. 38 static Value createIntConst(Location loc, Type type, int64_t value, 39 OpBuilder &b) { 40 auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); 41 if (auto shapedTy = dyn_cast<ShapedType>(type)) { 42 return b.create<arith::ConstantOp>(loc, 43 DenseElementsAttr::get(shapedTy, attr)); 44 } 45 46 return b.create<arith::ConstantOp>(loc, attr); 47 } 48 49 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) { 50 Type opType = operand.getType(); 51 Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand); 52 Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert); 53 return fpFixedConvert; 54 } 55 56 /// Expands tanh op into 57 /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 58 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0 59 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { 60 auto floatType = op.getOperand().getType(); 61 Location loc = op.getLoc(); 62 Value one = createFloatConst(loc, floatType, 1.0, rewriter); 63 Value two = createFloatConst(loc, floatType, 2.0, rewriter); 64 Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two); 65 66 // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x} 67 Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX); 68 Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX); 69 Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x); 70 Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x); 71 Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 72 73 // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1 74 exp2x = rewriter.create<math::ExpOp>(loc, doubledX); 75 dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one); 76 divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one); 77 Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 78 79 // tanh(x) = x >= 0 ? positiveRes : negativeRes 80 Value zero = createFloatConst(loc, floatType, 0.0, rewriter); 81 Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, 82 op.getOperand(), zero); 83 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes, 84 negativeRes); 85 return success(); 86 } 87 88 // Converts math.tan to math.sin, math.cos, and arith.divf. 89 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { 90 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 91 Value operand = op.getOperand(); 92 Type type = operand.getType(); 93 Value sin = b.create<math::SinOp>(type, operand); 94 Value cos = b.create<math::CosOp>(type, operand); 95 Value div = b.create<arith::DivFOp>(type, sin, cos); 96 rewriter.replaceOp(op, div); 97 return success(); 98 } 99 100 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) { 101 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 102 Value operandA = op.getOperand(0); 103 Value operandB = op.getOperand(1); 104 Value operandC = op.getOperand(2); 105 Type type = op.getType(); 106 Value mult = b.create<arith::MulFOp>(type, operandA, operandB); 107 Value add = b.create<arith::AddFOp>(type, mult, operandC); 108 rewriter.replaceOp(op, add); 109 return success(); 110 } 111 112 // Converts a floorf() function to the following: 113 // floorf(float x) -> 114 // y = (float)(int) x 115 // if (x < 0) then incr = -1 else incr = 0 116 // y = y + incr <= replace this op with the floorf op. 117 static LogicalResult convertFloorOp(math::FloorOp op, 118 PatternRewriter &rewriter) { 119 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 120 Value operand = op.getOperand(); 121 Type opType = operand.getType(); 122 Value fpFixedConvert = createTruncatedFPValue(operand, b); 123 124 // Creating constants for later use. 125 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 126 Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter); 127 128 Value negCheck = 129 b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero); 130 Value incrValue = 131 b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero); 132 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); 133 rewriter.replaceOp(op, ret); 134 return success(); 135 } 136 137 // Converts a ceilf() function to the following: 138 // ceilf(float x) -> 139 // y = (float)(int) x 140 // if (x > y) then incr = 1 else incr = 0 141 // y = y + incr <= replace this op with the ceilf op. 142 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { 143 ImplicitLocOpBuilder b(op->getLoc(), rewriter); 144 Value operand = op.getOperand(); 145 Type opType = operand.getType(); 146 Value fpFixedConvert = createTruncatedFPValue(operand, b); 147 148 // Creating constants for later use. 149 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter); 150 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter); 151 152 Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, 153 fpFixedConvert); 154 Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero); 155 156 Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue); 157 rewriter.replaceOp(op, ret); 158 return success(); 159 } 160 161 // Converts math.ctlz to scf and arith operations. This is done 162 // by performing a binary search on the bits. 163 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, 164 PatternRewriter &rewriter) { 165 auto operand = op.getOperand(); 166 auto operandTy = operand.getType(); 167 auto eTy = getElementTypeOrSelf(operandTy); 168 Location loc = op.getLoc(); 169 170 int32_t bitwidth = eTy.getIntOrFloatBitWidth(); 171 if (bitwidth > 64) 172 return failure(); 173 174 uint64_t allbits = -1; 175 if (bitwidth < 64) { 176 allbits = allbits >> (64 - bitwidth); 177 } 178 179 Value x = operand; 180 Value count = createIntConst(loc, operandTy, 0, rewriter); 181 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) { 182 auto half = bw / 2; 183 auto bits = createIntConst(loc, operandTy, half, rewriter); 184 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); 185 186 Value pred = 187 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask); 188 Value add = rewriter.create<arith::AddIOp>(loc, count, bits); 189 Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits); 190 191 x = rewriter.create<arith::SelectOp>(loc, pred, shift, x); 192 count = rewriter.create<arith::SelectOp>(loc, pred, add, count); 193 } 194 195 Value zero = createIntConst(loc, operandTy, 0, rewriter); 196 Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, 197 operand, zero); 198 199 Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); 200 Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count); 201 rewriter.replaceOp(op, sel); 202 return success(); 203 } 204 205 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { 206 patterns.add(convertCtlzOp); 207 } 208 209 void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { 210 patterns.add(convertTanOp); 211 } 212 213 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { 214 patterns.add(convertTanhOp); 215 } 216 217 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { 218 patterns.add(convertFmaFOp); 219 } 220 221 void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { 222 patterns.add(convertCeilOp); 223 } 224 225 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) { 226 patterns.add(convertFloorOp); 227 } 228