xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp (revision 94cf80d6fd612b36d5a21bf655643e46af2b8802)
1 //===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/IR/Linalg.h"
10 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
11 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
12 #include "mlir/IR/Dominance.h"
13 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
14 
15 using namespace mlir;
16 
17 // Determine whether the value is defined to be zero.
18 static bool isDefinedAsZero(Value val) {
19   if (!val)
20     return false;
21 
22   // Check whether val is a constant scalar / vector splat / tensor splat float
23   // or integer zero.
24   if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
25     return true;
26 
27   return TypeSwitch<Operation *, bool>(val.getDefiningOp())
28       .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
29         return op && op.getInputs().size() == 1 &&
30                isDefinedAsZero(op.getInputs()[0]);
31       })
32       .Default([&](auto) { return false; });
33 }
34 
35 /// Replace a linalg.add with one operand the single user of a contraction,
36 /// which has a zero-filled, "identity-mapped" destination and is dominated by
37 /// the `other` operand, by the contraction with `other` as its dest.
38 ///
39 /// As an example, the following pseudo-code will be rewritten
40 ///   %cst = arith.constant 0.000000e+00
41 ///   %empty = tensor.empty()
42 ///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
43 ///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
44 ///   %empty2 = tensor.empty()
45 ///   %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
46 ///   %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
47 ///   %out = linalg.add ins(%C, %F) outs(%empty)
48 /// to:
49 ///   %cst = arith.constant 0.000000e+00
50 ///   %empty = tensor.empty()
51 ///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
52 ///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
53 ///   %out = linalg.matmul ins(%D, %E) outs(%C)
54 ///
55 struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
56   using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
57 
58   LogicalResult matchAndRewrite(linalg::AddOp addOp,
59                                 PatternRewriter &rewriter) const override {
60     // For now, pattern only applies on tensor types (memref support is TODO).
61     if (!addOp.hasPureTensorSemantics())
62       return failure();
63 
64     Value dominatingOperand = nullptr;
65     linalg::LinalgOp dominatedOp = nullptr;
66     { // We will forget about which operand was left or right after this block.
67       Value lhs = addOp.getInputs()[0];
68       Value rhs = addOp.getInputs()[1];
69 
70       // Can only put one of addOp's operands in the dest/out arg of the other's
71       // defining op based on suitable dominance.
72       // TODO: Can be generalized to move ops around as long as that still
73       //       respects use-def chains and doesn't affect side-effects.
74       if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) {
75         DominanceInfo domInfo(rhsOp);
76         if (domInfo.properlyDominates(lhs, rhsOp)) {
77           dominatingOperand = lhs;
78           dominatedOp = rhsOp;
79         }
80       }
81       if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) {
82         DominanceInfo domInfo(lhsOp);
83         if (domInfo.properlyDominates(rhs, lhsOp)) {
84           dominatingOperand = rhs;
85           dominatedOp = lhsOp;
86         }
87       }
88       if (!dominatingOperand || !dominatedOp)
89         return failure();
90       // NB: As linalg.add's generalisation ignores the out argument in its
91       //     region there is no need to perform checks on addOp's out argument.
92     }
93 
94     // When dominated op is a contraction we know it accumulates on its out arg.
95     // E.g., AddOp is not a contraction and hence ignores its out arg's value.
96     // TODO: Generalize check to also pass in case of other LinalgOps that
97     //       accumulate on their out arg but are not (binary) contraction ops.
98     auto dominatedDestOp =
99         dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
100     if (dominatedOp->getNumResults() != 1 ||
101         !linalg::isaContractionOpInterface(dominatedOp) ||
102         (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
103       return rewriter.notifyMatchFailure(
104           dominatedOp, "expected dominated op to be single-result "
105                        "destination-passing contraction");
106 
107     // To change the contraction's result, `addOp` must be its only user.
108     if (!dominatedOp->getResult(0).hasOneUse())
109       return rewriter.notifyMatchFailure(
110           dominatedOp,
111           "expected linalg.add to be single user of contraction's result");
112 
113     // As `dominatedOp` was already accumulating on its out argument, it is only
114     // safe to no longer use its current out arg when it is the additive ident.
115     auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
116     if (!isDefinedAsZero(destOperand->get()))
117       return rewriter.notifyMatchFailure(
118           dominatedOp, "expected dominated op's dest to be additive zero");
119     // TODO: If the other op is a contraction and has additive ident as dest, we
120     // can swap the dests and achieve the proper sum, given suitable dominance.
121 
122     // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
123     // Hence, we can only substitute `dominatingOperand` for the dest of the
124     // contraction when dest's indexing_map corresponds to an identity map
125     // w.r.t. just the dimensions of dest, i.e. is an ordered projection.
126     SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
127     int prevDimPos = -1;
128     for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
129       auto dim = dyn_cast<AffineDimExpr>(expr);
130       if (!dim || prevDimPos > static_cast<int>(dim.getPosition()))
131         return rewriter.notifyMatchFailure(
132             dominatedOp, "expected index_map for contraction's dest to be an "
133                          "ordered projection");
134       prevDimPos = dim.getPosition();
135     }
136 
137     // Replace the additive-ident, i.e. zero, out arg of the dominated op by the
138     // dominating summand. This makes the dominated op's result the sum of both
139     // of addOp's arguments - therefore we replace addOp and it uses by it.
140     rewriter.modifyOpInPlace(
141         dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
142     rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
143     return success();
144   }
145 };
146 
147 void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) {
148   // Replace linalg.add when destination passing suffices for achieving the sum.
149   patterns.add<FoldAddIntoDest>(patterns.getContext());
150 }
151