xref: /llvm-project/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (revision 5550c821897ab77e664977121a0e90ad5be1ff59)
1d94426d2SEugene Zhulenev //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2d94426d2SEugene Zhulenev //
3d94426d2SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d94426d2SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5d94426d2SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d94426d2SEugene Zhulenev //
7d94426d2SEugene Zhulenev //===----------------------------------------------------------------------===//
8d94426d2SEugene Zhulenev //
9d94426d2SEugene Zhulenev // This file implements rewrites based on the basic rules of algebra
10d94426d2SEugene Zhulenev // (Commutativity, associativity, etc...) and strength reductions for math
11d94426d2SEugene Zhulenev // operations.
12d94426d2SEugene Zhulenev //
13d94426d2SEugene Zhulenev //===----------------------------------------------------------------------===//
14d94426d2SEugene Zhulenev 
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
16d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/IR/Math.h"
17d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/Transforms/Passes.h"
1899ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
19d94426d2SEugene Zhulenev #include "mlir/IR/Builders.h"
20d94426d2SEugene Zhulenev #include "mlir/IR/Matchers.h"
21d94426d2SEugene Zhulenev #include "mlir/IR/TypeUtilities.h"
22d94426d2SEugene Zhulenev #include <climits>
23d94426d2SEugene Zhulenev 
24d94426d2SEugene Zhulenev using namespace mlir;
25d94426d2SEugene Zhulenev 
26d94426d2SEugene Zhulenev //----------------------------------------------------------------------------//
27d94426d2SEugene Zhulenev // PowFOp strength reduction.
28d94426d2SEugene Zhulenev //----------------------------------------------------------------------------//
29d94426d2SEugene Zhulenev 
30d94426d2SEugene Zhulenev namespace {
31d94426d2SEugene Zhulenev struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32d94426d2SEugene Zhulenev public:
33d94426d2SEugene Zhulenev   using OpRewritePattern::OpRewritePattern;
34d94426d2SEugene Zhulenev 
35d94426d2SEugene Zhulenev   LogicalResult matchAndRewrite(math::PowFOp op,
36d94426d2SEugene Zhulenev                                 PatternRewriter &rewriter) const final;
37d94426d2SEugene Zhulenev };
38d94426d2SEugene Zhulenev } // namespace
39d94426d2SEugene Zhulenev 
40d94426d2SEugene Zhulenev LogicalResult
matchAndRewrite(math::PowFOp op,PatternRewriter & rewriter) const41d94426d2SEugene Zhulenev PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42d94426d2SEugene Zhulenev                                        PatternRewriter &rewriter) const {
43d94426d2SEugene Zhulenev   Location loc = op.getLoc();
4462fea88bSJacques Pienaar   Value x = op.getLhs();
45d94426d2SEugene Zhulenev 
46d94426d2SEugene Zhulenev   FloatAttr scalarExponent;
47d94426d2SEugene Zhulenev   DenseFPElementsAttr vectorExponent;
48d94426d2SEugene Zhulenev 
4962fea88bSJacques Pienaar   bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
5062fea88bSJacques Pienaar   bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
51d94426d2SEugene Zhulenev 
52d94426d2SEugene Zhulenev   // Returns true if exponent is a constant equal to `value`.
53d94426d2SEugene Zhulenev   auto isExponentValue = [&](double value) -> bool {
54d94426d2SEugene Zhulenev     if (isScalar)
55d94426d2SEugene Zhulenev       return scalarExponent.getValue().isExactlyValue(value);
56d94426d2SEugene Zhulenev 
57d94426d2SEugene Zhulenev     if (isVector && vectorExponent.isSplat())
58d94426d2SEugene Zhulenev       return vectorExponent.getSplatValue<FloatAttr>()
59d94426d2SEugene Zhulenev           .getValue()
60d94426d2SEugene Zhulenev           .isExactlyValue(value);
61d94426d2SEugene Zhulenev 
62d94426d2SEugene Zhulenev     return false;
63d94426d2SEugene Zhulenev   };
64d94426d2SEugene Zhulenev 
65d94426d2SEugene Zhulenev   // Maybe broadcasts scalar value into vector type compatible with `op`.
66d94426d2SEugene Zhulenev   auto bcast = [&](Value value) -> Value {
67*5550c821STres Popp     if (auto vec = dyn_cast<VectorType>(op.getType()))
68d94426d2SEugene Zhulenev       return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
69d94426d2SEugene Zhulenev     return value;
70d94426d2SEugene Zhulenev   };
71d94426d2SEugene Zhulenev 
72d94426d2SEugene Zhulenev   // Replace `pow(x, 1.0)` with `x`.
73d94426d2SEugene Zhulenev   if (isExponentValue(1.0)) {
74d94426d2SEugene Zhulenev     rewriter.replaceOp(op, x);
75d94426d2SEugene Zhulenev     return success();
76d94426d2SEugene Zhulenev   }
77d94426d2SEugene Zhulenev 
78d94426d2SEugene Zhulenev   // Replace `pow(x, 2.0)` with `x * x`.
79d94426d2SEugene Zhulenev   if (isExponentValue(2.0)) {
80a54f4eaeSMogball     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81d94426d2SEugene Zhulenev     return success();
82d94426d2SEugene Zhulenev   }
83d94426d2SEugene Zhulenev 
84391456f3Sbakhtiyar   // Replace `pow(x, 3.0)` with `x * x * x`.
85d94426d2SEugene Zhulenev   if (isExponentValue(3.0)) {
86a54f4eaeSMogball     Value square =
87a54f4eaeSMogball         rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
88a54f4eaeSMogball     rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89d94426d2SEugene Zhulenev     return success();
90d94426d2SEugene Zhulenev   }
91d94426d2SEugene Zhulenev 
92d94426d2SEugene Zhulenev   // Replace `pow(x, -1.0)` with `1.0 / x`.
93d94426d2SEugene Zhulenev   if (isExponentValue(-1.0)) {
94a54f4eaeSMogball     Value one = rewriter.create<arith::ConstantOp>(
95d94426d2SEugene Zhulenev         loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
96a54f4eaeSMogball     rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
97d94426d2SEugene Zhulenev     return success();
98d94426d2SEugene Zhulenev   }
99d94426d2SEugene Zhulenev 
100391456f3Sbakhtiyar   // Replace `pow(x, 0.5)` with `sqrt(x)`.
101391456f3Sbakhtiyar   if (isExponentValue(0.5)) {
102d94426d2SEugene Zhulenev     rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
103d94426d2SEugene Zhulenev     return success();
104d94426d2SEugene Zhulenev   }
105d94426d2SEugene Zhulenev 
106391456f3Sbakhtiyar   // Replace `pow(x, -0.5)` with `rsqrt(x)`.
107391456f3Sbakhtiyar   if (isExponentValue(-0.5)) {
108391456f3Sbakhtiyar     rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
109391456f3Sbakhtiyar     return success();
110391456f3Sbakhtiyar   }
111391456f3Sbakhtiyar 
112095ce655SSlava Zakharin   // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
113095ce655SSlava Zakharin   if (isExponentValue(0.75)) {
114beee5740SMehdi Amini     Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
115beee5740SMehdi Amini     Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
116beee5740SMehdi Amini     rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
117beee5740SMehdi Amini                                                ValueRange{powHalf, powQuarter});
118095ce655SSlava Zakharin     return success();
119095ce655SSlava Zakharin   }
120095ce655SSlava Zakharin 
121d94426d2SEugene Zhulenev   return failure();
122d94426d2SEugene Zhulenev }
123d94426d2SEugene Zhulenev 
124d94426d2SEugene Zhulenev //----------------------------------------------------------------------------//
125f9d988f1SSlava Zakharin // FPowIOp/IPowIOp strength reduction.
1262dde4ba6SSlava Zakharin //----------------------------------------------------------------------------//
1272dde4ba6SSlava Zakharin 
1282dde4ba6SSlava Zakharin namespace {
129f9d988f1SSlava Zakharin template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
130f9d988f1SSlava Zakharin struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
131f9d988f1SSlava Zakharin 
1322dde4ba6SSlava Zakharin   unsigned exponentThreshold;
1332dde4ba6SSlava Zakharin 
1342dde4ba6SSlava Zakharin public:
PowIStrengthReduction__anonc3006bad0411::PowIStrengthReduction135f9d988f1SSlava Zakharin   PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
1362dde4ba6SSlava Zakharin                         PatternBenefit benefit = 1,
1372dde4ba6SSlava Zakharin                         ArrayRef<StringRef> generatedNames = {})
138f9d988f1SSlava Zakharin       : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
1392dde4ba6SSlava Zakharin         exponentThreshold(exponentThreshold) {}
140f9d988f1SSlava Zakharin 
141f9d988f1SSlava Zakharin   LogicalResult matchAndRewrite(PowIOpTy op,
1422dde4ba6SSlava Zakharin                                 PatternRewriter &rewriter) const final;
1432dde4ba6SSlava Zakharin };
1442dde4ba6SSlava Zakharin } // namespace
1452dde4ba6SSlava Zakharin 
146f9d988f1SSlava Zakharin template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
1472dde4ba6SSlava Zakharin LogicalResult
matchAndRewrite(PowIOpTy op,PatternRewriter & rewriter) const148f9d988f1SSlava Zakharin PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
149f9d988f1SSlava Zakharin     PowIOpTy op, PatternRewriter &rewriter) const {
1502dde4ba6SSlava Zakharin   Location loc = op.getLoc();
1512dde4ba6SSlava Zakharin   Value base = op.getLhs();
1522dde4ba6SSlava Zakharin 
1532dde4ba6SSlava Zakharin   IntegerAttr scalarExponent;
1542dde4ba6SSlava Zakharin   DenseIntElementsAttr vectorExponent;
1552dde4ba6SSlava Zakharin 
1562dde4ba6SSlava Zakharin   bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
1572dde4ba6SSlava Zakharin   bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
1582dde4ba6SSlava Zakharin 
1592dde4ba6SSlava Zakharin   // Simplify cases with known exponent value.
1602dde4ba6SSlava Zakharin   int64_t exponentValue = 0;
1612dde4ba6SSlava Zakharin   if (isScalar)
1622dde4ba6SSlava Zakharin     exponentValue = scalarExponent.getInt();
1632dde4ba6SSlava Zakharin   else if (isVector && vectorExponent.isSplat())
1642dde4ba6SSlava Zakharin     exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
1652dde4ba6SSlava Zakharin   else
1662dde4ba6SSlava Zakharin     return failure();
1672dde4ba6SSlava Zakharin 
1682dde4ba6SSlava Zakharin   // Maybe broadcasts scalar value into vector type compatible with `op`.
169f9d988f1SSlava Zakharin   auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
170*5550c821STres Popp     if (auto vec = dyn_cast<VectorType>(op.getType()))
1712dde4ba6SSlava Zakharin       return rewriter.create<vector::BroadcastOp>(loc, vec, value);
1722dde4ba6SSlava Zakharin     return value;
1732dde4ba6SSlava Zakharin   };
1742dde4ba6SSlava Zakharin 
175f9d988f1SSlava Zakharin   Value one;
176f9d988f1SSlava Zakharin   Type opType = getElementTypeOrSelf(op.getType());
177f9d988f1SSlava Zakharin   if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
178f9d988f1SSlava Zakharin     one = rewriter.create<arith::ConstantOp>(
179f9d988f1SSlava Zakharin         loc, rewriter.getFloatAttr(opType, 1.0));
180f9d988f1SSlava Zakharin   else
181f9d988f1SSlava Zakharin     one = rewriter.create<arith::ConstantOp>(
182f9d988f1SSlava Zakharin         loc, rewriter.getIntegerAttr(opType, 1));
183f9d988f1SSlava Zakharin 
184f9d988f1SSlava Zakharin   // Replace `[fi]powi(x, 0)` with `1`.
1852dde4ba6SSlava Zakharin   if (exponentValue == 0) {
1862dde4ba6SSlava Zakharin     rewriter.replaceOp(op, bcast(one));
1872dde4ba6SSlava Zakharin     return success();
1882dde4ba6SSlava Zakharin   }
1892dde4ba6SSlava Zakharin 
1902dde4ba6SSlava Zakharin   bool exponentIsNegative = false;
1912dde4ba6SSlava Zakharin   if (exponentValue < 0) {
1922dde4ba6SSlava Zakharin     exponentIsNegative = true;
1932dde4ba6SSlava Zakharin     exponentValue *= -1;
1942dde4ba6SSlava Zakharin   }
1952dde4ba6SSlava Zakharin 
1962dde4ba6SSlava Zakharin   // Bail out if `abs(exponent)` exceeds the threshold.
1972dde4ba6SSlava Zakharin   if (exponentValue > exponentThreshold)
1982dde4ba6SSlava Zakharin     return failure();
1992dde4ba6SSlava Zakharin 
2002dde4ba6SSlava Zakharin   // Inverse the base for negative exponent, i.e. for
201f9d988f1SSlava Zakharin   // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
202f9d988f1SSlava Zakharin   if (exponentIsNegative)
203f9d988f1SSlava Zakharin     base = rewriter.create<DivOpTy>(loc, bcast(one), base);
2042dde4ba6SSlava Zakharin 
2052dde4ba6SSlava Zakharin   Value result = base;
2062dde4ba6SSlava Zakharin   // Transform to naive sequence of multiplications:
2072dde4ba6SSlava Zakharin   //   * For positive exponent case replace:
208f9d988f1SSlava Zakharin   //       `[fi]powi(x, positive_exponent)`
2092dde4ba6SSlava Zakharin   //     with:
2102dde4ba6SSlava Zakharin   //       x * x * x * ...
2112dde4ba6SSlava Zakharin   //   * For negative exponent case replace:
212f9d988f1SSlava Zakharin   //       `[fi]powi(x, negative_exponent)`
2132dde4ba6SSlava Zakharin   //     with:
2142dde4ba6SSlava Zakharin   //       (1 / x) * (1 / x) * (1 / x) * ...
2152dde4ba6SSlava Zakharin   for (unsigned i = 1; i < exponentValue; ++i)
216f9d988f1SSlava Zakharin     result = rewriter.create<MulOpTy>(loc, result, base);
2172dde4ba6SSlava Zakharin 
2182dde4ba6SSlava Zakharin   rewriter.replaceOp(op, result);
2192dde4ba6SSlava Zakharin   return success();
2202dde4ba6SSlava Zakharin }
2212dde4ba6SSlava Zakharin 
2222dde4ba6SSlava Zakharin //----------------------------------------------------------------------------//
223d94426d2SEugene Zhulenev 
populateMathAlgebraicSimplificationPatterns(RewritePatternSet & patterns)224d94426d2SEugene Zhulenev void mlir::populateMathAlgebraicSimplificationPatterns(
225d94426d2SEugene Zhulenev     RewritePatternSet &patterns) {
226f9d988f1SSlava Zakharin   patterns
227f9d988f1SSlava Zakharin       .add<PowFStrengthReduction,
228f9d988f1SSlava Zakharin            PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
229f9d988f1SSlava Zakharin            PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
2302dde4ba6SSlava Zakharin           patterns.getContext());
231d94426d2SEugene Zhulenev }
232