1 //===- UpliftToFMA.cpp - Arith to FMA uplifting ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements uplifting from arith ops to math.fma. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/Math/Transforms/Passes.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 19 namespace mlir::math { 20 #define GEN_PASS_DEF_MATHUPLIFTTOFMA 21 #include "mlir/Dialect/Math/Transforms/Passes.h.inc" 22 } // namespace mlir::math 23 24 using namespace mlir; 25 26 template <typename Op> 27 static bool isValidForFMA(Op op) { 28 return static_cast<bool>(op.getFastmath() & arith::FastMathFlags::contract); 29 } 30 31 namespace { 32 33 struct UpliftFma final : OpRewritePattern<arith::AddFOp> { 34 using OpRewritePattern::OpRewritePattern; 35 36 LogicalResult matchAndRewrite(arith::AddFOp op, 37 PatternRewriter &rewriter) const override { 38 if (!isValidForFMA(op)) 39 return rewriter.notifyMatchFailure(op, "addf op is not suitable for fma"); 40 41 Value c; 42 arith::MulFOp ab; 43 if ((ab = op.getLhs().getDefiningOp<arith::MulFOp>())) { 44 c = op.getRhs(); 45 } else if ((ab = op.getRhs().getDefiningOp<arith::MulFOp>())) { 46 c = op.getLhs(); 47 } else { 48 return rewriter.notifyMatchFailure(op, "no mulf op"); 49 } 50 51 if (!isValidForFMA(ab)) 52 return rewriter.notifyMatchFailure(ab, "mulf op is not suitable for fma"); 53 54 Value a = ab.getLhs(); 55 Value b = ab.getRhs(); 56 arith::FastMathFlags fmf = op.getFastmath() & ab.getFastmath(); 57 rewriter.replaceOpWithNewOp<math::FmaOp>(op, a, b, c, fmf); 58 return success(); 59 } 60 }; 61 62 struct MathUpliftToFMA final 63 : math::impl::MathUpliftToFMABase<MathUpliftToFMA> { 64 using MathUpliftToFMABase::MathUpliftToFMABase; 65 66 void runOnOperation() override { 67 RewritePatternSet patterns(&getContext()); 68 populateUpliftToFMAPatterns(patterns); 69 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 70 return signalPassFailure(); 71 } 72 }; 73 74 } // namespace 75 76 void mlir::populateUpliftToFMAPatterns(RewritePatternSet &patterns) { 77 patterns.insert<UpliftFma>(patterns.getContext()); 78 } 79