xref: /llvm-project/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1ee8b8d6bSIvan Butygin //===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===//
2ee8b8d6bSIvan Butygin //
3ee8b8d6bSIvan Butygin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ee8b8d6bSIvan Butygin // See https://llvm.org/LICENSE.txt for license information.
5ee8b8d6bSIvan Butygin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ee8b8d6bSIvan Butygin //
7ee8b8d6bSIvan Butygin //===----------------------------------------------------------------------===//
8ee8b8d6bSIvan Butygin //
9ee8b8d6bSIvan Butygin // This file implements uplifting from arith ops to math.fma.
10ee8b8d6bSIvan Butygin //
11ee8b8d6bSIvan Butygin //===----------------------------------------------------------------------===//
12ee8b8d6bSIvan Butygin 
13ee8b8d6bSIvan Butygin #include "mlir/Dialect/Arith/IR/Arith.h"
14ee8b8d6bSIvan Butygin #include "mlir/Dialect/Math/IR/Math.h"
15ee8b8d6bSIvan Butygin #include "mlir/Dialect/Math/Transforms/Passes.h"
16ee8b8d6bSIvan Butygin #include "mlir/IR/PatternMatch.h"
17ee8b8d6bSIvan Butygin #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18ee8b8d6bSIvan Butygin 
19ee8b8d6bSIvan Butygin namespace mlir::math {
20ee8b8d6bSIvan Butygin #define GEN_PASS_DEF_MATHUPLIFTTOFMA
21ee8b8d6bSIvan Butygin #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
22ee8b8d6bSIvan Butygin } // namespace mlir::math
23ee8b8d6bSIvan Butygin 
24ee8b8d6bSIvan Butygin using namespace mlir;
25ee8b8d6bSIvan Butygin 
26ee8b8d6bSIvan Butygin template <typename Op>
27ee8b8d6bSIvan Butygin static bool isValidForFMA(Op op) {
28ee8b8d6bSIvan Butygin   return static_cast<bool>(op.getFastmath() & arith::FastMathFlags::contract);
29ee8b8d6bSIvan Butygin }
30ee8b8d6bSIvan Butygin 
31ee8b8d6bSIvan Butygin namespace {
32ee8b8d6bSIvan Butygin 
33ee8b8d6bSIvan Butygin struct UpliftFma final : OpRewritePattern<arith::AddFOp> {
34ee8b8d6bSIvan Butygin   using OpRewritePattern::OpRewritePattern;
35ee8b8d6bSIvan Butygin 
36ee8b8d6bSIvan Butygin   LogicalResult matchAndRewrite(arith::AddFOp op,
37ee8b8d6bSIvan Butygin                                 PatternRewriter &rewriter) const override {
38ee8b8d6bSIvan Butygin     if (!isValidForFMA(op))
39ee8b8d6bSIvan Butygin       return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma");
40ee8b8d6bSIvan Butygin 
41ee8b8d6bSIvan Butygin     Value c;
42ee8b8d6bSIvan Butygin     arith::MulFOp ab;
43ee8b8d6bSIvan Butygin     if ((ab = op.getLhs().getDefiningOp<arith::MulFOp>())) {
44ee8b8d6bSIvan Butygin       c = op.getRhs();
45ee8b8d6bSIvan Butygin     } else if ((ab = op.getRhs().getDefiningOp<arith::MulFOp>())) {
46ee8b8d6bSIvan Butygin       c = op.getLhs();
47ee8b8d6bSIvan Butygin     } else {
48ee8b8d6bSIvan Butygin       return rewriter.notifyMatchFailure(op, "no mulf op");
49ee8b8d6bSIvan Butygin     }
50ee8b8d6bSIvan Butygin 
51ee8b8d6bSIvan Butygin     if (!isValidForFMA(ab))
52ee8b8d6bSIvan Butygin       return rewriter.notifyMatchFailure(ab, "mulf op is not suitable for fma");
53ee8b8d6bSIvan Butygin 
54ee8b8d6bSIvan Butygin     Value a = ab.getLhs();
55ee8b8d6bSIvan Butygin     Value b = ab.getRhs();
56ee8b8d6bSIvan Butygin     arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath();
57ee8b8d6bSIvan Butygin     rewriter.replaceOpWithNewOp<math::FmaOp>(op, a, b, c, fmf);
58ee8b8d6bSIvan Butygin     return success();
59ee8b8d6bSIvan Butygin   }
60ee8b8d6bSIvan Butygin };
61ee8b8d6bSIvan Butygin 
62ee8b8d6bSIvan Butygin struct MathUpliftToFMA final
63ee8b8d6bSIvan Butygin     : math::impl::MathUpliftToFMABase<MathUpliftToFMA> {
64ee8b8d6bSIvan Butygin   using MathUpliftToFMABase::MathUpliftToFMABase;
65ee8b8d6bSIvan Butygin 
66ee8b8d6bSIvan Butygin   void runOnOperation() override {
67ee8b8d6bSIvan Butygin     RewritePatternSet patterns(&getContext());
68ee8b8d6bSIvan Butygin     populateUpliftToFMAPatterns(patterns);
69*09dfc571SJacques Pienaar     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
70ee8b8d6bSIvan Butygin       return signalPassFailure();
71ee8b8d6bSIvan Butygin   }
72ee8b8d6bSIvan Butygin };
73ee8b8d6bSIvan Butygin 
74ee8b8d6bSIvan Butygin } // namespace
75ee8b8d6bSIvan Butygin 
76ee8b8d6bSIvan Butygin void mlir::populateUpliftToFMAPatterns(RewritePatternSet &patterns) {
77ee8b8d6bSIvan Butygin   patterns.insert<UpliftFma>(patterns.getContext());
78ee8b8d6bSIvan Butygin }
79