xref: /llvm-project/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (revision 5550c821897ab77e664977121a0e90ad5be1ff59)
1 //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements rewrites based on the basic rules of algebra
10 // (Commutativity, associativity, etc...) and strength reductions for math
11 // operations.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/Math/Transforms/Passes.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include <climits>
23 
24 using namespace mlir;
25 
26 //----------------------------------------------------------------------------//
27 // PowFOp strength reduction.
28 //----------------------------------------------------------------------------//
29 
30 namespace {
31 struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32 public:
33   using OpRewritePattern::OpRewritePattern;
34 
35   LogicalResult matchAndRewrite(math::PowFOp op,
36                                 PatternRewriter &rewriter) const final;
37 };
38 } // namespace
39 
40 LogicalResult
matchAndRewrite(math::PowFOp op,PatternRewriter & rewriter) const41 PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42                                        PatternRewriter &rewriter) const {
43   Location loc = op.getLoc();
44   Value x = op.getLhs();
45 
46   FloatAttr scalarExponent;
47   DenseFPElementsAttr vectorExponent;
48 
49   bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
50   bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
51 
52   // Returns true if exponent is a constant equal to `value`.
53   auto isExponentValue = [&](double value) -> bool {
54     if (isScalar)
55       return scalarExponent.getValue().isExactlyValue(value);
56 
57     if (isVector && vectorExponent.isSplat())
58       return vectorExponent.getSplatValue<FloatAttr>()
59           .getValue()
60           .isExactlyValue(value);
61 
62     return false;
63   };
64 
65   // Maybe broadcasts scalar value into vector type compatible with `op`.
66   auto bcast = [&](Value value) -> Value {
67     if (auto vec = dyn_cast<VectorType>(op.getType()))
68       return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
69     return value;
70   };
71 
72   // Replace `pow(x, 1.0)` with `x`.
73   if (isExponentValue(1.0)) {
74     rewriter.replaceOp(op, x);
75     return success();
76   }
77 
78   // Replace `pow(x, 2.0)` with `x * x`.
79   if (isExponentValue(2.0)) {
80     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81     return success();
82   }
83 
84   // Replace `pow(x, 3.0)` with `x * x * x`.
85   if (isExponentValue(3.0)) {
86     Value square =
87         rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
88     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89     return success();
90   }
91 
92   // Replace `pow(x, -1.0)` with `1.0 / x`.
93   if (isExponentValue(-1.0)) {
94     Value one = rewriter.create<arith::ConstantOp>(
95         loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
96     rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
97     return success();
98   }
99 
100   // Replace `pow(x, 0.5)` with `sqrt(x)`.
101   if (isExponentValue(0.5)) {
102     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
103     return success();
104   }
105 
106   // Replace `pow(x, -0.5)` with `rsqrt(x)`.
107   if (isExponentValue(-0.5)) {
108     rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
109     return success();
110   }
111 
112   // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
113   if (isExponentValue(0.75)) {
114     Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
115     Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
116     rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
117                                                ValueRange{powHalf, powQuarter});
118     return success();
119   }
120 
121   return failure();
122 }
123 
124 //----------------------------------------------------------------------------//
125 // FPowIOp/IPowIOp strength reduction.
126 //----------------------------------------------------------------------------//
127 
128 namespace {
129 template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
130 struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
131 
132   unsigned exponentThreshold;
133 
134 public:
PowIStrengthReduction__anonc3006bad0411::PowIStrengthReduction135   PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
136                         PatternBenefit benefit = 1,
137                         ArrayRef<StringRef> generatedNames = {})
138       : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
139         exponentThreshold(exponentThreshold) {}
140 
141   LogicalResult matchAndRewrite(PowIOpTy op,
142                                 PatternRewriter &rewriter) const final;
143 };
144 } // namespace
145 
146 template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
147 LogicalResult
matchAndRewrite(PowIOpTy op,PatternRewriter & rewriter) const148 PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
149     PowIOpTy op, PatternRewriter &rewriter) const {
150   Location loc = op.getLoc();
151   Value base = op.getLhs();
152 
153   IntegerAttr scalarExponent;
154   DenseIntElementsAttr vectorExponent;
155 
156   bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
157   bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
158 
159   // Simplify cases with known exponent value.
160   int64_t exponentValue = 0;
161   if (isScalar)
162     exponentValue = scalarExponent.getInt();
163   else if (isVector && vectorExponent.isSplat())
164     exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
165   else
166     return failure();
167 
168   // Maybe broadcasts scalar value into vector type compatible with `op`.
169   auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
170     if (auto vec = dyn_cast<VectorType>(op.getType()))
171       return rewriter.create<vector::BroadcastOp>(loc, vec, value);
172     return value;
173   };
174 
175   Value one;
176   Type opType = getElementTypeOrSelf(op.getType());
177   if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
178     one = rewriter.create<arith::ConstantOp>(
179         loc, rewriter.getFloatAttr(opType, 1.0));
180   else
181     one = rewriter.create<arith::ConstantOp>(
182         loc, rewriter.getIntegerAttr(opType, 1));
183 
184   // Replace `[fi]powi(x, 0)` with `1`.
185   if (exponentValue == 0) {
186     rewriter.replaceOp(op, bcast(one));
187     return success();
188   }
189 
190   bool exponentIsNegative = false;
191   if (exponentValue < 0) {
192     exponentIsNegative = true;
193     exponentValue *= -1;
194   }
195 
196   // Bail out if `abs(exponent)` exceeds the threshold.
197   if (exponentValue > exponentThreshold)
198     return failure();
199 
200   // Inverse the base for negative exponent, i.e. for
201   // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
202   if (exponentIsNegative)
203     base = rewriter.create<DivOpTy>(loc, bcast(one), base);
204 
205   Value result = base;
206   // Transform to naive sequence of multiplications:
207   //   * For positive exponent case replace:
208   //       `[fi]powi(x, positive_exponent)`
209   //     with:
210   //       x * x * x * ...
211   //   * For negative exponent case replace:
212   //       `[fi]powi(x, negative_exponent)`
213   //     with:
214   //       (1 / x) * (1 / x) * (1 / x) * ...
215   for (unsigned i = 1; i < exponentValue; ++i)
216     result = rewriter.create<MulOpTy>(loc, result, base);
217 
218   rewriter.replaceOp(op, result);
219   return success();
220 }
221 
222 //----------------------------------------------------------------------------//
223 
populateMathAlgebraicSimplificationPatterns(RewritePatternSet & patterns)224 void mlir::populateMathAlgebraicSimplificationPatterns(
225     RewritePatternSet &patterns) {
226   patterns
227       .add<PowFStrengthReduction,
228            PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
229            PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
230           patterns.getContext());
231 }
232