1 //===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/IR/Linalg.h" 10 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" 11 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 12 #include "mlir/IR/Dominance.h" 13 #include "mlir/Interfaces/DestinationStyleOpInterface.h" 14 15 using namespace mlir; 16 17 // Determine whether the value is defined to be zero. 18 static bool isDefinedAsZero(Value val) { 19 if (!val) 20 return false; 21 22 // Check whether val is a constant scalar / vector splat / tensor splat float 23 // or integer zero. 24 if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero())) 25 return true; 26 27 return TypeSwitch<Operation *, bool>(val.getDefiningOp()) 28 .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) { 29 return op && op.getInputs().size() == 1 && 30 isDefinedAsZero(op.getInputs()[0]); 31 }) 32 .Default([&](auto) { return false; }); 33 } 34 35 /// Replace a linalg.add with one operand the single user of a contraction, 36 /// which has a zero-filled, "identity-mapped" destination and is dominated by 37 /// the `other` operand, by the contraction with `other` as its dest. 38 /// 39 /// As an example, the following pseudo-code will be rewritten 40 /// %cst = arith.constant 0.000000e+00 41 /// %empty = tensor.empty() 42 /// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type 43 /// %C = linalg.matmul ins(%A, %B) outs(%zeroed) 44 /// %empty2 = tensor.empty() 45 /// %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type 46 /// %F = linalg.matmul ins(%D, %E) outs(%zeroed2) 47 /// %out = linalg.add ins(%C, %F) outs(%empty) 48 /// to: 49 /// %cst = arith.constant 0.000000e+00 50 /// %empty = tensor.empty() 51 /// %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type 52 /// %C = linalg.matmul ins(%A, %B) outs(%zeroed) 53 /// %out = linalg.matmul ins(%D, %E) outs(%C) 54 /// 55 struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> { 56 using OpRewritePattern<linalg::AddOp>::OpRewritePattern; 57 58 LogicalResult matchAndRewrite(linalg::AddOp addOp, 59 PatternRewriter &rewriter) const override { 60 // For now, pattern only applies on tensor types (memref support is TODO). 61 if (!addOp.hasPureTensorSemantics()) 62 return failure(); 63 64 Value dominatingOperand = nullptr; 65 linalg::LinalgOp dominatedOp = nullptr; 66 { // We will forget about which operand was left or right after this block. 67 Value lhs = addOp.getInputs()[0]; 68 Value rhs = addOp.getInputs()[1]; 69 70 // Can only put one of addOp's operands in the dest/out arg of the other's 71 // defining op based on suitable dominance. 72 // TODO: Can be generalized to move ops around as long as that still 73 // respects use-def chains and doesn't affect side-effects. 74 if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) { 75 DominanceInfo domInfo(rhsOp); 76 if (domInfo.properlyDominates(lhs, rhsOp)) { 77 dominatingOperand = lhs; 78 dominatedOp = rhsOp; 79 } 80 } 81 if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) { 82 DominanceInfo domInfo(lhsOp); 83 if (domInfo.properlyDominates(rhs, lhsOp)) { 84 dominatingOperand = rhs; 85 dominatedOp = lhsOp; 86 } 87 } 88 if (!dominatingOperand || !dominatedOp) 89 return failure(); 90 // NB: As linalg.add's generalisation ignores the out argument in its 91 // region there is no need to perform checks on addOp's out argument. 92 } 93 94 // When dominated op is a contraction we know it accumulates on its out arg. 95 // E.g., AddOp is not a contraction and hence ignores its out arg's value. 96 // TODO: Generalize check to also pass in case of other LinalgOps that 97 // accumulate on their out arg but are not (binary) contraction ops. 98 auto dominatedDestOp = 99 dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp); 100 if (dominatedOp->getNumResults() != 1 || 101 !linalg::isaContractionOpInterface(dominatedOp) || 102 (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1)) 103 return rewriter.notifyMatchFailure( 104 dominatedOp, "expected dominated op to be single-result " 105 "destination-passing contraction"); 106 107 // To change the contraction's result, `addOp` must be its only user. 108 if (!dominatedOp->getResult(0).hasOneUse()) 109 return rewriter.notifyMatchFailure( 110 dominatedOp, 111 "expected linalg.add to be single user of contraction's result"); 112 113 // As `dominatedOp` was already accumulating on its out argument, it is only 114 // safe to no longer use its current out arg when it is the additive ident. 115 auto *destOperand = dominatedDestOp.getDpsInitOperand(0); 116 if (!isDefinedAsZero(destOperand->get())) 117 return rewriter.notifyMatchFailure( 118 dominatedOp, "expected dominated op's dest to be additive zero"); 119 // TODO: If the other op is a contraction and has additive ident as dest, we 120 // can swap the dests and achieve the proper sum, given suitable dominance. 121 122 // As an operand to `addOp`, `dominatingOperand` has an identity affine_map. 123 // Hence, we can only substitute `dominatingOperand` for the dest of the 124 // contraction when dest's indexing_map corresponds to an identity map 125 // w.r.t. just the dimensions of dest, i.e. is an ordered projection. 126 SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray(); 127 int prevDimPos = -1; 128 for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) { 129 auto dim = dyn_cast<AffineDimExpr>(expr); 130 if (!dim || prevDimPos > static_cast<int>(dim.getPosition())) 131 return rewriter.notifyMatchFailure( 132 dominatedOp, "expected index_map for contraction's dest to be an " 133 "ordered projection"); 134 prevDimPos = dim.getPosition(); 135 } 136 137 // Replace the additive-ident, i.e. zero, out arg of the dominated op by the 138 // dominating summand. This makes the dominated op's result the sum of both 139 // of addOp's arguments - therefore we replace addOp and it uses by it. 140 rewriter.modifyOpInPlace( 141 dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); }); 142 rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0)); 143 return success(); 144 } 145 }; 146 147 void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) { 148 // Replace linalg.add when destination passing suffices for achieving the sum. 149 patterns.add<FoldAddIntoDest>(patterns.getContext()); 150 } 151