xref: /llvm-project/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (revision 4da96515ea8552cdf14c6aa6310d2a91fbe74641)
//===- ExpandPatterns.cpp - Expand math ops into simpler ops --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements expansions of math dialect ops (ctlz, tan, tanh, fma,
// ceil, floor, exp2) into simpler arith and math operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/Math/Transforms/Passes.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/Vector/IR/VectorOps.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/ImplicitLocOpBuilder.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Transforms/DialectConversion.h"
22 
23 using namespace mlir;
24 
25 /// Create a float constant.
26 static Value createFloatConst(Location loc, Type type, double value,
27                               OpBuilder &b) {
28   auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value);
29   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
30     return b.create<arith::ConstantOp>(loc,
31                                        DenseElementsAttr::get(shapedTy, attr));
32   }
33 
34   return b.create<arith::ConstantOp>(loc, attr);
35 }
36 
37 /// Create a float constant.
38 static Value createIntConst(Location loc, Type type, int64_t value,
39                             OpBuilder &b) {
40   auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
41   if (auto shapedTy = dyn_cast<ShapedType>(type)) {
42     return b.create<arith::ConstantOp>(loc,
43                                        DenseElementsAttr::get(shapedTy, attr));
44   }
45 
46   return b.create<arith::ConstantOp>(loc, attr);
47 }
48 
49 static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b) {
50   Type opType = operand.getType();
51   Value fixedConvert = b.create<arith::FPToSIOp>(b.getI64Type(), operand);
52   Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
53   return fpFixedConvert;
54 }
55 
56 /// Expands tanh op into
57 ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
58 ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
59 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
60   auto floatType = op.getOperand().getType();
61   Location loc = op.getLoc();
62   Value one = createFloatConst(loc, floatType, 1.0, rewriter);
63   Value two = createFloatConst(loc, floatType, 2.0, rewriter);
64   Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
65 
66   // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
67   Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
68   Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
69   Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
70   Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
71   Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
72 
73   // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
74   exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
75   dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
76   divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
77   Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
78 
79   // tanh(x) = x >= 0 ? positiveRes : negativeRes
80   Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
81   Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
82                                                 op.getOperand(), zero);
83   rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
84                                                negativeRes);
85   return success();
86 }
87 
88 // Converts math.tan to math.sin, math.cos, and arith.divf.
89 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
90   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
91   Value operand = op.getOperand();
92   Type type = operand.getType();
93   Value sin = b.create<math::SinOp>(type, operand);
94   Value cos = b.create<math::CosOp>(type, operand);
95   Value div = b.create<arith::DivFOp>(type, sin, cos);
96   rewriter.replaceOp(op, div);
97   return success();
98 }
99 
100 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
101   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
102   Value operandA = op.getOperand(0);
103   Value operandB = op.getOperand(1);
104   Value operandC = op.getOperand(2);
105   Type type = op.getType();
106   Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
107   Value add = b.create<arith::AddFOp>(type, mult, operandC);
108   rewriter.replaceOp(op, add);
109   return success();
110 }
111 
112 // Converts a floorf() function to the following:
113 // floorf(float x) ->
114 //     y = (float)(int) x
115 //     if (x < 0) then incr = -1 else incr = 0
116 //     y = y + incr    <= replace this op with the floorf op.
117 static LogicalResult convertFloorOp(math::FloorOp op,
118                                     PatternRewriter &rewriter) {
119   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
120   Value operand = op.getOperand();
121   Type opType = operand.getType();
122   Value fpFixedConvert = createTruncatedFPValue(operand, b);
123 
124   // Creating constants for later use.
125   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
126   Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
127 
128   Value negCheck =
129       b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
130   Value incrValue =
131       b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
132   Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
133   rewriter.replaceOp(op, ret);
134   return success();
135 }
136 
137 // Converts a ceilf() function to the following:
138 // ceilf(float x) ->
139 //      y = (float)(int) x
140 //      if (x > y) then incr = 1 else incr = 0
141 //      y = y + incr   <= replace this op with the ceilf op.
142 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
143   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
144   Value operand = op.getOperand();
145   Type opType = operand.getType();
146   Value fpFixedConvert = createTruncatedFPValue(operand, b);
147 
148   // Creating constants for later use.
149   Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
150   Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
151 
152   Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
153                                           fpFixedConvert);
154   Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
155 
156   Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
157   rewriter.replaceOp(op, ret);
158   return success();
159 }
160 
161 // exp2f(float x) -> exp(x * ln(2))
162 //   Proof: Let's say 2^x = y
163 //   ln(2^x) = ln(y)
164 //   x * ln(2) = ln(y) => e ^(x*ln(2)) = y
165 static LogicalResult convertExp2fOp(math::Exp2Op op,
166                                     PatternRewriter &rewriter) {
167   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
168   Value operand = op.getOperand();
169   Type opType = operand.getType();
170   Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
171   Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
172   Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
173   rewriter.replaceOp(op, exp);
174   return success();
175 }
176 
177 // Converts math.ctlz to scf and arith operations. This is done
178 // by performing a binary search on the bits.
179 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
180                                    PatternRewriter &rewriter) {
181   auto operand = op.getOperand();
182   auto operandTy = operand.getType();
183   auto eTy = getElementTypeOrSelf(operandTy);
184   Location loc = op.getLoc();
185 
186   int32_t bitwidth = eTy.getIntOrFloatBitWidth();
187   if (bitwidth > 64)
188     return failure();
189 
190   uint64_t allbits = -1;
191   if (bitwidth < 64) {
192     allbits = allbits >> (64 - bitwidth);
193   }
194 
195   Value x = operand;
196   Value count = createIntConst(loc, operandTy, 0, rewriter);
197   for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
198     auto half = bw / 2;
199     auto bits = createIntConst(loc, operandTy, half, rewriter);
200     auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
201 
202     Value pred =
203         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
204     Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
205     Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
206 
207     x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
208     count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
209   }
210 
211   Value zero = createIntConst(loc, operandTy, 0, rewriter);
212   Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
213                                               operand, zero);
214 
215   Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
216   Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
217   rewriter.replaceOp(op, sel);
218   return success();
219 }
220 
221 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
222   patterns.add(convertCtlzOp);
223 }
224 
225 void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
226   patterns.add(convertTanOp);
227 }
228 
229 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
230   patterns.add(convertTanhOp);
231 }
232 
233 void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
234   patterns.add(convertFmaFOp);
235 }
236 
237 void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
238   patterns.add(convertCeilOp);
239 }
240 
241 void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
242   patterns.add(convertExp2fOp);
243 }
244 
245 void mlir::populateExpandFloorFPattern(RewritePatternSet &patterns) {
246   patterns.add(convertFloorOp);
247 }
248