1d94426d2SEugene Zhulenev //===- AlgebraicSimplification.cpp - Simplify algebraic expressions -------===//
2d94426d2SEugene Zhulenev //
3d94426d2SEugene Zhulenev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d94426d2SEugene Zhulenev // See https://llvm.org/LICENSE.txt for license information.
5d94426d2SEugene Zhulenev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d94426d2SEugene Zhulenev //
7d94426d2SEugene Zhulenev //===----------------------------------------------------------------------===//
8d94426d2SEugene Zhulenev //
9d94426d2SEugene Zhulenev // This file implements rewrites based on the basic rules of algebra
10d94426d2SEugene Zhulenev // (Commutativity, associativity, etc...) and strength reductions for math
11d94426d2SEugene Zhulenev // operations.
12d94426d2SEugene Zhulenev //
13d94426d2SEugene Zhulenev //===----------------------------------------------------------------------===//
14d94426d2SEugene Zhulenev
15abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
16d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/IR/Math.h"
17d94426d2SEugene Zhulenev #include "mlir/Dialect/Math/Transforms/Passes.h"
1899ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
19d94426d2SEugene Zhulenev #include "mlir/IR/Builders.h"
20d94426d2SEugene Zhulenev #include "mlir/IR/Matchers.h"
21d94426d2SEugene Zhulenev #include "mlir/IR/TypeUtilities.h"
22d94426d2SEugene Zhulenev #include <climits>
23d94426d2SEugene Zhulenev
24d94426d2SEugene Zhulenev using namespace mlir;
25d94426d2SEugene Zhulenev
26d94426d2SEugene Zhulenev //----------------------------------------------------------------------------//
27d94426d2SEugene Zhulenev // PowFOp strength reduction.
28d94426d2SEugene Zhulenev //----------------------------------------------------------------------------//
29d94426d2SEugene Zhulenev
30d94426d2SEugene Zhulenev namespace {
31d94426d2SEugene Zhulenev struct PowFStrengthReduction : public OpRewritePattern<math::PowFOp> {
32d94426d2SEugene Zhulenev public:
33d94426d2SEugene Zhulenev using OpRewritePattern::OpRewritePattern;
34d94426d2SEugene Zhulenev
35d94426d2SEugene Zhulenev LogicalResult matchAndRewrite(math::PowFOp op,
36d94426d2SEugene Zhulenev PatternRewriter &rewriter) const final;
37d94426d2SEugene Zhulenev };
38d94426d2SEugene Zhulenev } // namespace
39d94426d2SEugene Zhulenev
40d94426d2SEugene Zhulenev LogicalResult
matchAndRewrite(math::PowFOp op,PatternRewriter & rewriter) const41d94426d2SEugene Zhulenev PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
42d94426d2SEugene Zhulenev PatternRewriter &rewriter) const {
43d94426d2SEugene Zhulenev Location loc = op.getLoc();
4462fea88bSJacques Pienaar Value x = op.getLhs();
45d94426d2SEugene Zhulenev
46d94426d2SEugene Zhulenev FloatAttr scalarExponent;
47d94426d2SEugene Zhulenev DenseFPElementsAttr vectorExponent;
48d94426d2SEugene Zhulenev
4962fea88bSJacques Pienaar bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
5062fea88bSJacques Pienaar bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
51d94426d2SEugene Zhulenev
52d94426d2SEugene Zhulenev // Returns true if exponent is a constant equal to `value`.
53d94426d2SEugene Zhulenev auto isExponentValue = [&](double value) -> bool {
54d94426d2SEugene Zhulenev if (isScalar)
55d94426d2SEugene Zhulenev return scalarExponent.getValue().isExactlyValue(value);
56d94426d2SEugene Zhulenev
57d94426d2SEugene Zhulenev if (isVector && vectorExponent.isSplat())
58d94426d2SEugene Zhulenev return vectorExponent.getSplatValue<FloatAttr>()
59d94426d2SEugene Zhulenev .getValue()
60d94426d2SEugene Zhulenev .isExactlyValue(value);
61d94426d2SEugene Zhulenev
62d94426d2SEugene Zhulenev return false;
63d94426d2SEugene Zhulenev };
64d94426d2SEugene Zhulenev
65d94426d2SEugene Zhulenev // Maybe broadcasts scalar value into vector type compatible with `op`.
66d94426d2SEugene Zhulenev auto bcast = [&](Value value) -> Value {
67*5550c821STres Popp if (auto vec = dyn_cast<VectorType>(op.getType()))
68d94426d2SEugene Zhulenev return rewriter.create<vector::BroadcastOp>(op.getLoc(), vec, value);
69d94426d2SEugene Zhulenev return value;
70d94426d2SEugene Zhulenev };
71d94426d2SEugene Zhulenev
72d94426d2SEugene Zhulenev // Replace `pow(x, 1.0)` with `x`.
73d94426d2SEugene Zhulenev if (isExponentValue(1.0)) {
74d94426d2SEugene Zhulenev rewriter.replaceOp(op, x);
75d94426d2SEugene Zhulenev return success();
76d94426d2SEugene Zhulenev }
77d94426d2SEugene Zhulenev
78d94426d2SEugene Zhulenev // Replace `pow(x, 2.0)` with `x * x`.
79d94426d2SEugene Zhulenev if (isExponentValue(2.0)) {
80a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
81d94426d2SEugene Zhulenev return success();
82d94426d2SEugene Zhulenev }
83d94426d2SEugene Zhulenev
84391456f3Sbakhtiyar // Replace `pow(x, 3.0)` with `x * x * x`.
85d94426d2SEugene Zhulenev if (isExponentValue(3.0)) {
86a54f4eaeSMogball Value square =
87a54f4eaeSMogball rewriter.create<arith::MulFOp>(op.getLoc(), ValueRange({x, x}));
88a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
89d94426d2SEugene Zhulenev return success();
90d94426d2SEugene Zhulenev }
91d94426d2SEugene Zhulenev
92d94426d2SEugene Zhulenev // Replace `pow(x, -1.0)` with `1.0 / x`.
93d94426d2SEugene Zhulenev if (isExponentValue(-1.0)) {
94a54f4eaeSMogball Value one = rewriter.create<arith::ConstantOp>(
95d94426d2SEugene Zhulenev loc, rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
96a54f4eaeSMogball rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
97d94426d2SEugene Zhulenev return success();
98d94426d2SEugene Zhulenev }
99d94426d2SEugene Zhulenev
100391456f3Sbakhtiyar // Replace `pow(x, 0.5)` with `sqrt(x)`.
101391456f3Sbakhtiyar if (isExponentValue(0.5)) {
102d94426d2SEugene Zhulenev rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
103d94426d2SEugene Zhulenev return success();
104d94426d2SEugene Zhulenev }
105d94426d2SEugene Zhulenev
106391456f3Sbakhtiyar // Replace `pow(x, -0.5)` with `rsqrt(x)`.
107391456f3Sbakhtiyar if (isExponentValue(-0.5)) {
108391456f3Sbakhtiyar rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
109391456f3Sbakhtiyar return success();
110391456f3Sbakhtiyar }
111391456f3Sbakhtiyar
112095ce655SSlava Zakharin // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
113095ce655SSlava Zakharin if (isExponentValue(0.75)) {
114beee5740SMehdi Amini Value powHalf = rewriter.create<math::SqrtOp>(op.getLoc(), x);
115beee5740SMehdi Amini Value powQuarter = rewriter.create<math::SqrtOp>(op.getLoc(), powHalf);
116beee5740SMehdi Amini rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
117beee5740SMehdi Amini ValueRange{powHalf, powQuarter});
118095ce655SSlava Zakharin return success();
119095ce655SSlava Zakharin }
120095ce655SSlava Zakharin
121d94426d2SEugene Zhulenev return failure();
122d94426d2SEugene Zhulenev }
123d94426d2SEugene Zhulenev
124d94426d2SEugene Zhulenev //----------------------------------------------------------------------------//
125f9d988f1SSlava Zakharin // FPowIOp/IPowIOp strength reduction.
1262dde4ba6SSlava Zakharin //----------------------------------------------------------------------------//
1272dde4ba6SSlava Zakharin
1282dde4ba6SSlava Zakharin namespace {
129f9d988f1SSlava Zakharin template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
130f9d988f1SSlava Zakharin struct PowIStrengthReduction : public OpRewritePattern<PowIOpTy> {
131f9d988f1SSlava Zakharin
1322dde4ba6SSlava Zakharin unsigned exponentThreshold;
1332dde4ba6SSlava Zakharin
1342dde4ba6SSlava Zakharin public:
PowIStrengthReduction__anonc3006bad0411::PowIStrengthReduction135f9d988f1SSlava Zakharin PowIStrengthReduction(MLIRContext *context, unsigned exponentThreshold = 3,
1362dde4ba6SSlava Zakharin PatternBenefit benefit = 1,
1372dde4ba6SSlava Zakharin ArrayRef<StringRef> generatedNames = {})
138f9d988f1SSlava Zakharin : OpRewritePattern<PowIOpTy>(context, benefit, generatedNames),
1392dde4ba6SSlava Zakharin exponentThreshold(exponentThreshold) {}
140f9d988f1SSlava Zakharin
141f9d988f1SSlava Zakharin LogicalResult matchAndRewrite(PowIOpTy op,
1422dde4ba6SSlava Zakharin PatternRewriter &rewriter) const final;
1432dde4ba6SSlava Zakharin };
1442dde4ba6SSlava Zakharin } // namespace
1452dde4ba6SSlava Zakharin
146f9d988f1SSlava Zakharin template <typename PowIOpTy, typename DivOpTy, typename MulOpTy>
1472dde4ba6SSlava Zakharin LogicalResult
matchAndRewrite(PowIOpTy op,PatternRewriter & rewriter) const148f9d988f1SSlava Zakharin PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
149f9d988f1SSlava Zakharin PowIOpTy op, PatternRewriter &rewriter) const {
1502dde4ba6SSlava Zakharin Location loc = op.getLoc();
1512dde4ba6SSlava Zakharin Value base = op.getLhs();
1522dde4ba6SSlava Zakharin
1532dde4ba6SSlava Zakharin IntegerAttr scalarExponent;
1542dde4ba6SSlava Zakharin DenseIntElementsAttr vectorExponent;
1552dde4ba6SSlava Zakharin
1562dde4ba6SSlava Zakharin bool isScalar = matchPattern(op.getRhs(), m_Constant(&scalarExponent));
1572dde4ba6SSlava Zakharin bool isVector = matchPattern(op.getRhs(), m_Constant(&vectorExponent));
1582dde4ba6SSlava Zakharin
1592dde4ba6SSlava Zakharin // Simplify cases with known exponent value.
1602dde4ba6SSlava Zakharin int64_t exponentValue = 0;
1612dde4ba6SSlava Zakharin if (isScalar)
1622dde4ba6SSlava Zakharin exponentValue = scalarExponent.getInt();
1632dde4ba6SSlava Zakharin else if (isVector && vectorExponent.isSplat())
1642dde4ba6SSlava Zakharin exponentValue = vectorExponent.getSplatValue<IntegerAttr>().getInt();
1652dde4ba6SSlava Zakharin else
1662dde4ba6SSlava Zakharin return failure();
1672dde4ba6SSlava Zakharin
1682dde4ba6SSlava Zakharin // Maybe broadcasts scalar value into vector type compatible with `op`.
169f9d988f1SSlava Zakharin auto bcast = [&loc, &op, &rewriter](Value value) -> Value {
170*5550c821STres Popp if (auto vec = dyn_cast<VectorType>(op.getType()))
1712dde4ba6SSlava Zakharin return rewriter.create<vector::BroadcastOp>(loc, vec, value);
1722dde4ba6SSlava Zakharin return value;
1732dde4ba6SSlava Zakharin };
1742dde4ba6SSlava Zakharin
175f9d988f1SSlava Zakharin Value one;
176f9d988f1SSlava Zakharin Type opType = getElementTypeOrSelf(op.getType());
177f9d988f1SSlava Zakharin if constexpr (std::is_same_v<PowIOpTy, math::FPowIOp>)
178f9d988f1SSlava Zakharin one = rewriter.create<arith::ConstantOp>(
179f9d988f1SSlava Zakharin loc, rewriter.getFloatAttr(opType, 1.0));
180f9d988f1SSlava Zakharin else
181f9d988f1SSlava Zakharin one = rewriter.create<arith::ConstantOp>(
182f9d988f1SSlava Zakharin loc, rewriter.getIntegerAttr(opType, 1));
183f9d988f1SSlava Zakharin
184f9d988f1SSlava Zakharin // Replace `[fi]powi(x, 0)` with `1`.
1852dde4ba6SSlava Zakharin if (exponentValue == 0) {
1862dde4ba6SSlava Zakharin rewriter.replaceOp(op, bcast(one));
1872dde4ba6SSlava Zakharin return success();
1882dde4ba6SSlava Zakharin }
1892dde4ba6SSlava Zakharin
1902dde4ba6SSlava Zakharin bool exponentIsNegative = false;
1912dde4ba6SSlava Zakharin if (exponentValue < 0) {
1922dde4ba6SSlava Zakharin exponentIsNegative = true;
1932dde4ba6SSlava Zakharin exponentValue *= -1;
1942dde4ba6SSlava Zakharin }
1952dde4ba6SSlava Zakharin
1962dde4ba6SSlava Zakharin // Bail out if `abs(exponent)` exceeds the threshold.
1972dde4ba6SSlava Zakharin if (exponentValue > exponentThreshold)
1982dde4ba6SSlava Zakharin return failure();
1992dde4ba6SSlava Zakharin
2002dde4ba6SSlava Zakharin // Inverse the base for negative exponent, i.e. for
201f9d988f1SSlava Zakharin // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
202f9d988f1SSlava Zakharin if (exponentIsNegative)
203f9d988f1SSlava Zakharin base = rewriter.create<DivOpTy>(loc, bcast(one), base);
2042dde4ba6SSlava Zakharin
2052dde4ba6SSlava Zakharin Value result = base;
2062dde4ba6SSlava Zakharin // Transform to naive sequence of multiplications:
2072dde4ba6SSlava Zakharin // * For positive exponent case replace:
208f9d988f1SSlava Zakharin // `[fi]powi(x, positive_exponent)`
2092dde4ba6SSlava Zakharin // with:
2102dde4ba6SSlava Zakharin // x * x * x * ...
2112dde4ba6SSlava Zakharin // * For negative exponent case replace:
212f9d988f1SSlava Zakharin // `[fi]powi(x, negative_exponent)`
2132dde4ba6SSlava Zakharin // with:
2142dde4ba6SSlava Zakharin // (1 / x) * (1 / x) * (1 / x) * ...
2152dde4ba6SSlava Zakharin for (unsigned i = 1; i < exponentValue; ++i)
216f9d988f1SSlava Zakharin result = rewriter.create<MulOpTy>(loc, result, base);
2172dde4ba6SSlava Zakharin
2182dde4ba6SSlava Zakharin rewriter.replaceOp(op, result);
2192dde4ba6SSlava Zakharin return success();
2202dde4ba6SSlava Zakharin }
2212dde4ba6SSlava Zakharin
2222dde4ba6SSlava Zakharin //----------------------------------------------------------------------------//
223d94426d2SEugene Zhulenev
populateMathAlgebraicSimplificationPatterns(RewritePatternSet & patterns)224d94426d2SEugene Zhulenev void mlir::populateMathAlgebraicSimplificationPatterns(
225d94426d2SEugene Zhulenev RewritePatternSet &patterns) {
226f9d988f1SSlava Zakharin patterns
227f9d988f1SSlava Zakharin .add<PowFStrengthReduction,
228f9d988f1SSlava Zakharin PowIStrengthReduction<math::IPowIOp, arith::DivSIOp, arith::MulIOp>,
229f9d988f1SSlava Zakharin PowIStrengthReduction<math::FPowIOp, arith::DivFOp, arith::MulFOp>>(
2302dde4ba6SSlava Zakharin patterns.getContext());
231d94426d2SEugene Zhulenev }
232