1*94cf80d6SRolf Morel //===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===// 2*94cf80d6SRolf Morel // 3*94cf80d6SRolf Morel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*94cf80d6SRolf Morel // See https://llvm.org/LICENSE.txt for license information. 5*94cf80d6SRolf Morel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*94cf80d6SRolf Morel // 7*94cf80d6SRolf Morel //===----------------------------------------------------------------------===// 8*94cf80d6SRolf Morel 9*94cf80d6SRolf Morel #include "mlir/Dialect/Linalg/IR/Linalg.h" 10*94cf80d6SRolf Morel #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" 11*94cf80d6SRolf Morel #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 12*94cf80d6SRolf Morel #include "mlir/IR/Dominance.h" 13*94cf80d6SRolf Morel #include "mlir/Interfaces/DestinationStyleOpInterface.h" 14*94cf80d6SRolf Morel 15*94cf80d6SRolf Morel using namespace mlir; 16*94cf80d6SRolf Morel 17*94cf80d6SRolf Morel // Determine whether the value is defined to be zero. 18*94cf80d6SRolf Morel static bool isDefinedAsZero(Value val) { 19*94cf80d6SRolf Morel if (!val) 20*94cf80d6SRolf Morel return false; 21*94cf80d6SRolf Morel 22*94cf80d6SRolf Morel // Check whether val is a constant scalar / vector splat / tensor splat float 23*94cf80d6SRolf Morel // or integer zero. 24*94cf80d6SRolf Morel if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero())) 25*94cf80d6SRolf Morel return true; 26*94cf80d6SRolf Morel 27*94cf80d6SRolf Morel return TypeSwitch<Operation *, bool>(val.getDefiningOp()) 28*94cf80d6SRolf Morel .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) { 29*94cf80d6SRolf Morel return op && op.getInputs().size() == 1 && 30*94cf80d6SRolf Morel isDefinedAsZero(op.getInputs()[0]); 31*94cf80d6SRolf Morel }) 32*94cf80d6SRolf Morel .Default([&](auto) { return false; }); 33*94cf80d6SRolf Morel } 34*94cf80d6SRolf Morel 35*94cf80d6SRolf Morel /// Replace a linalg.add with one operand the single user of a contraction, 36*94cf80d6SRolf Morel /// which has a zero-filled, "identity-mapped" destination and is dominated by 37*94cf80d6SRolf Morel /// the `other` operand, by the contraction with `other` as its dest. 38*94cf80d6SRolf Morel /// 39*94cf80d6SRolf Morel /// As an example, the following pseudo-code will be rewritten 40*94cf80d6SRolf Morel /// %cst = arith.constant 0.000000e+00 41*94cf80d6SRolf Morel /// %empty = tensor.empty() 42*94cf80d6SRolf Morel /// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type 43*94cf80d6SRolf Morel /// %C = linalg.matmul ins(%A, %B) outs(%zeroed) 44*94cf80d6SRolf Morel /// %empty2 = tensor.empty() 45*94cf80d6SRolf Morel /// %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type 46*94cf80d6SRolf Morel /// %F = linalg.matmul ins(%D, %E) outs(%zeroed2) 47*94cf80d6SRolf Morel /// %out = linalg.add ins(%C, %F) outs(%empty) 48*94cf80d6SRolf Morel /// to: 49*94cf80d6SRolf Morel /// %cst = arith.constant 0.000000e+00 50*94cf80d6SRolf Morel /// %empty = tensor.empty() 51*94cf80d6SRolf Morel /// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type 52*94cf80d6SRolf Morel /// %C = linalg.matmul ins(%A, %B) outs(%zeroed) 53*94cf80d6SRolf Morel /// %out = linalg.matmul ins(%D, %E) outs(%C) 54*94cf80d6SRolf Morel /// 55*94cf80d6SRolf Morel struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> { 56*94cf80d6SRolf Morel using OpRewritePattern<linalg::AddOp>::OpRewritePattern; 57*94cf80d6SRolf Morel 58*94cf80d6SRolf Morel LogicalResult matchAndRewrite(linalg::AddOp addOp, 59*94cf80d6SRolf Morel PatternRewriter &rewriter) const override { 60*94cf80d6SRolf Morel // For now, pattern only applies on tensor types (memref support is TODO). 61*94cf80d6SRolf Morel if (!addOp.hasPureTensorSemantics()) 62*94cf80d6SRolf Morel return failure(); 63*94cf80d6SRolf Morel 64*94cf80d6SRolf Morel Value dominatingOperand = nullptr; 65*94cf80d6SRolf Morel linalg::LinalgOp dominatedOp = nullptr; 66*94cf80d6SRolf Morel { // We will forget about which operand was left or right after this block. 67*94cf80d6SRolf Morel Value lhs = addOp.getInputs()[0]; 68*94cf80d6SRolf Morel Value rhs = addOp.getInputs()[1]; 69*94cf80d6SRolf Morel 70*94cf80d6SRolf Morel // Can only put one of addOp's operands in the dest/out arg of the other's 71*94cf80d6SRolf Morel // defining op based on suitable dominance. 72*94cf80d6SRolf Morel // TODO: Can be generalized to move ops around as long as that still 73*94cf80d6SRolf Morel // respects use-def chains and doesn't affect side-effects. 74*94cf80d6SRolf Morel if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) { 75*94cf80d6SRolf Morel DominanceInfo domInfo(rhsOp); 76*94cf80d6SRolf Morel if (domInfo.properlyDominates(lhs, rhsOp)) { 77*94cf80d6SRolf Morel dominatingOperand = lhs; 78*94cf80d6SRolf Morel dominatedOp = rhsOp; 79*94cf80d6SRolf Morel } 80*94cf80d6SRolf Morel } 81*94cf80d6SRolf Morel if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) { 82*94cf80d6SRolf Morel DominanceInfo domInfo(lhsOp); 83*94cf80d6SRolf Morel if (domInfo.properlyDominates(rhs, lhsOp)) { 84*94cf80d6SRolf Morel dominatingOperand = rhs; 85*94cf80d6SRolf Morel dominatedOp = lhsOp; 86*94cf80d6SRolf Morel } 87*94cf80d6SRolf Morel } 88*94cf80d6SRolf Morel if (!dominatingOperand || !dominatedOp) 89*94cf80d6SRolf Morel return failure(); 90*94cf80d6SRolf Morel // NB: As linalg.add's generalisation ignores the out argument in its 91*94cf80d6SRolf Morel // region there is no need to perform checks on addOp's out argument. 92*94cf80d6SRolf Morel } 93*94cf80d6SRolf Morel 94*94cf80d6SRolf Morel // When dominated op is a contraction we know it accumulates on its out arg. 95*94cf80d6SRolf Morel // E.g., AddOp is not a contraction and hence ignores its out arg's value. 96*94cf80d6SRolf Morel // TODO: Generalize check to also pass in case of other LinalgOps that 97*94cf80d6SRolf Morel // accumulate on their out arg but are not (binary) contraction ops. 98*94cf80d6SRolf Morel auto dominatedDestOp = 99*94cf80d6SRolf Morel dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp); 100*94cf80d6SRolf Morel if (dominatedOp->getNumResults() != 1 || 101*94cf80d6SRolf Morel !linalg::isaContractionOpInterface(dominatedOp) || 102*94cf80d6SRolf Morel (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1)) 103*94cf80d6SRolf Morel return rewriter.notifyMatchFailure( 104*94cf80d6SRolf Morel dominatedOp, "expected dominated op to be single-result " 105*94cf80d6SRolf Morel "destination-passing contraction"); 106*94cf80d6SRolf Morel 107*94cf80d6SRolf Morel // To change the contraction's result, `addOp` must be its only user. 108*94cf80d6SRolf Morel if (!dominatedOp->getResult(0).hasOneUse()) 109*94cf80d6SRolf Morel return rewriter.notifyMatchFailure( 110*94cf80d6SRolf Morel dominatedOp, 111*94cf80d6SRolf Morel "expected linalg.add to be single user of contraction's result"); 112*94cf80d6SRolf Morel 113*94cf80d6SRolf Morel // As `dominatedOp` was already accumulating on its out argument, it is only 114*94cf80d6SRolf Morel // safe to no longer use its current out arg when it is the additive ident. 115*94cf80d6SRolf Morel auto *destOperand = dominatedDestOp.getDpsInitOperand(0); 116*94cf80d6SRolf Morel if (!isDefinedAsZero(destOperand->get())) 117*94cf80d6SRolf Morel return rewriter.notifyMatchFailure( 118*94cf80d6SRolf Morel dominatedOp, "expected dominated op's dest to be additive zero"); 119*94cf80d6SRolf Morel // TODO: If the other op is a contraction and has additive ident as dest, we 120*94cf80d6SRolf Morel // can swap the dests and achieve the proper sum, given suitable dominance. 121*94cf80d6SRolf Morel 122*94cf80d6SRolf Morel // As an operand to `addOp`, `dominatingOperand` has an identity affine_map. 123*94cf80d6SRolf Morel // Hence, we can only substitute `dominatingOperand` for the dest of the 124*94cf80d6SRolf Morel // contraction when dest's indexing_map corresponds to an identity map 125*94cf80d6SRolf Morel // w.r.t. just the dimensions of dest, i.e. is an ordered projection. 126*94cf80d6SRolf Morel SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray(); 127*94cf80d6SRolf Morel int prevDimPos = -1; 128*94cf80d6SRolf Morel for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) { 129*94cf80d6SRolf Morel auto dim = dyn_cast<AffineDimExpr>(expr); 130*94cf80d6SRolf Morel if (!dim || prevDimPos > static_cast<int>(dim.getPosition())) 131*94cf80d6SRolf Morel return rewriter.notifyMatchFailure( 132*94cf80d6SRolf Morel dominatedOp, "expected index_map for contraction's dest to be an " 133*94cf80d6SRolf Morel "ordered projection"); 134*94cf80d6SRolf Morel prevDimPos = dim.getPosition(); 135*94cf80d6SRolf Morel } 136*94cf80d6SRolf Morel 137*94cf80d6SRolf Morel // Replace the additive-ident, i.e. zero, out arg of the dominated op by the 138*94cf80d6SRolf Morel // dominating summand. This makes the dominated op's result the sum of both 139*94cf80d6SRolf Morel // of addOp's arguments - therefore we replace addOp and it uses by it. 140*94cf80d6SRolf Morel rewriter.modifyOpInPlace( 141*94cf80d6SRolf Morel dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); }); 142*94cf80d6SRolf Morel rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0)); 143*94cf80d6SRolf Morel return success(); 144*94cf80d6SRolf Morel } 145*94cf80d6SRolf Morel }; 146*94cf80d6SRolf Morel 147*94cf80d6SRolf Morel void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) { 148*94cf80d6SRolf Morel // Replace linalg.add when destination passing suffices for achieving the sum. 149*94cf80d6SRolf Morel patterns.add<FoldAddIntoDest>(patterns.getContext()); 150*94cf80d6SRolf Morel } 151