1ee8b8d6bSIvan Butygin //===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===// 2ee8b8d6bSIvan Butygin // 3ee8b8d6bSIvan Butygin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4ee8b8d6bSIvan Butygin // See https://llvm.org/LICENSE.txt for license information. 5ee8b8d6bSIvan Butygin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6ee8b8d6bSIvan Butygin // 7ee8b8d6bSIvan Butygin //===----------------------------------------------------------------------===// 8ee8b8d6bSIvan Butygin // 9ee8b8d6bSIvan Butygin // This file implements uplifting from arith ops to math.fma. 10ee8b8d6bSIvan Butygin // 11ee8b8d6bSIvan Butygin //===----------------------------------------------------------------------===// 12ee8b8d6bSIvan Butygin 13ee8b8d6bSIvan Butygin #include "mlir/Dialect/Arith/IR/Arith.h" 14ee8b8d6bSIvan Butygin #include "mlir/Dialect/Math/IR/Math.h" 15ee8b8d6bSIvan Butygin #include "mlir/Dialect/Math/Transforms/Passes.h" 16ee8b8d6bSIvan Butygin #include "mlir/IR/PatternMatch.h" 17ee8b8d6bSIvan Butygin #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18ee8b8d6bSIvan Butygin 19ee8b8d6bSIvan Butygin namespace mlir::math { 20ee8b8d6bSIvan Butygin #define GEN_PASS_DEF_MATHUPLIFTTOFMA 21ee8b8d6bSIvan Butygin #include "mlir/Dialect/Math/Transforms/Passes.h.inc" 22ee8b8d6bSIvan Butygin } // namespace mlir::math 23ee8b8d6bSIvan Butygin 24ee8b8d6bSIvan Butygin using namespace mlir; 25ee8b8d6bSIvan Butygin 26ee8b8d6bSIvan Butygin template <typename Op> 27ee8b8d6bSIvan Butygin static bool isValidForFMA(Op op) { 28ee8b8d6bSIvan Butygin return static_cast<bool>(op.getFastmath() & arith::FastMathFlags::contract); 29ee8b8d6bSIvan Butygin } 30ee8b8d6bSIvan Butygin 31ee8b8d6bSIvan Butygin namespace { 32ee8b8d6bSIvan Butygin 33ee8b8d6bSIvan Butygin struct UpliftFma final : OpRewritePattern<arith::AddFOp> { 34ee8b8d6bSIvan Butygin using OpRewritePattern::OpRewritePattern; 35ee8b8d6bSIvan Butygin 36ee8b8d6bSIvan Butygin LogicalResult matchAndRewrite(arith::AddFOp op, 37ee8b8d6bSIvan Butygin PatternRewriter &rewriter) const override { 38ee8b8d6bSIvan Butygin if (!isValidForFMA(op)) 39ee8b8d6bSIvan Butygin return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma"); 40ee8b8d6bSIvan Butygin 41ee8b8d6bSIvan Butygin Value c; 42ee8b8d6bSIvan Butygin arith::MulFOp ab; 43ee8b8d6bSIvan Butygin if ((ab = op.getLhs().getDefiningOp<arith::MulFOp>())) { 44ee8b8d6bSIvan Butygin c = op.getRhs(); 45ee8b8d6bSIvan Butygin } else if ((ab = op.getRhs().getDefiningOp<arith::MulFOp>())) { 46ee8b8d6bSIvan Butygin c = op.getLhs(); 47ee8b8d6bSIvan Butygin } else { 48ee8b8d6bSIvan Butygin return rewriter.notifyMatchFailure(op, "no mulf op"); 49ee8b8d6bSIvan Butygin } 50ee8b8d6bSIvan Butygin 51ee8b8d6bSIvan Butygin if (!isValidForFMA(ab)) 52ee8b8d6bSIvan Butygin return rewriter.notifyMatchFailure(ab, "mulf op is not suitable for fma"); 53ee8b8d6bSIvan Butygin 54ee8b8d6bSIvan Butygin Value a = ab.getLhs(); 55ee8b8d6bSIvan Butygin Value b = ab.getRhs(); 56ee8b8d6bSIvan Butygin arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath(); 57ee8b8d6bSIvan Butygin rewriter.replaceOpWithNewOp<math::FmaOp>(op, a, b, c, fmf); 58ee8b8d6bSIvan Butygin return success(); 59ee8b8d6bSIvan Butygin } 60ee8b8d6bSIvan Butygin }; 61ee8b8d6bSIvan Butygin 62ee8b8d6bSIvan Butygin struct MathUpliftToFMA final 63ee8b8d6bSIvan Butygin : math::impl::MathUpliftToFMABase<MathUpliftToFMA> { 64ee8b8d6bSIvan Butygin using MathUpliftToFMABase::MathUpliftToFMABase; 65ee8b8d6bSIvan Butygin 66ee8b8d6bSIvan Butygin void runOnOperation() override { 67ee8b8d6bSIvan Butygin RewritePatternSet patterns(&getContext()); 68ee8b8d6bSIvan Butygin populateUpliftToFMAPatterns(patterns); 69*09dfc571SJacques Pienaar if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 70ee8b8d6bSIvan Butygin return signalPassFailure(); 71ee8b8d6bSIvan Butygin } 72ee8b8d6bSIvan Butygin }; 73ee8b8d6bSIvan Butygin 74ee8b8d6bSIvan Butygin } // namespace 75ee8b8d6bSIvan Butygin 76ee8b8d6bSIvan Butygin void mlir::populateUpliftToFMAPatterns(RewritePatternSet &patterns) { 77ee8b8d6bSIvan Butygin patterns.insert<UpliftFma>(patterns.getContext()); 78ee8b8d6bSIvan Butygin } 79