xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp (revision 94cf80d6fd612b36d5a21bf655643e46af2b8802)
1*94cf80d6SRolf Morel //===- FoldAddIntoDest.cpp ---------------------------------------*- C++-*-===//
2*94cf80d6SRolf Morel //
3*94cf80d6SRolf Morel // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*94cf80d6SRolf Morel // See https://llvm.org/LICENSE.txt for license information.
5*94cf80d6SRolf Morel // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*94cf80d6SRolf Morel //
7*94cf80d6SRolf Morel //===----------------------------------------------------------------------===//
8*94cf80d6SRolf Morel 
9*94cf80d6SRolf Morel #include "mlir/Dialect/Linalg/IR/Linalg.h"
10*94cf80d6SRolf Morel #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
11*94cf80d6SRolf Morel #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
12*94cf80d6SRolf Morel #include "mlir/IR/Dominance.h"
13*94cf80d6SRolf Morel #include "mlir/Interfaces/DestinationStyleOpInterface.h"
14*94cf80d6SRolf Morel 
15*94cf80d6SRolf Morel using namespace mlir;
16*94cf80d6SRolf Morel 
17*94cf80d6SRolf Morel // Determine whether the value is defined to be zero.
18*94cf80d6SRolf Morel static bool isDefinedAsZero(Value val) {
19*94cf80d6SRolf Morel   if (!val)
20*94cf80d6SRolf Morel     return false;
21*94cf80d6SRolf Morel 
22*94cf80d6SRolf Morel   // Check whether val is a constant scalar / vector splat / tensor splat float
23*94cf80d6SRolf Morel   // or integer zero.
24*94cf80d6SRolf Morel   if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
25*94cf80d6SRolf Morel     return true;
26*94cf80d6SRolf Morel 
27*94cf80d6SRolf Morel   return TypeSwitch<Operation *, bool>(val.getDefiningOp())
28*94cf80d6SRolf Morel       .Case<linalg::FillOp, linalg::CopyOp>([&](auto op) {
29*94cf80d6SRolf Morel         return op && op.getInputs().size() == 1 &&
30*94cf80d6SRolf Morel                isDefinedAsZero(op.getInputs()[0]);
31*94cf80d6SRolf Morel       })
32*94cf80d6SRolf Morel       .Default([&](auto) { return false; });
33*94cf80d6SRolf Morel }
34*94cf80d6SRolf Morel 
35*94cf80d6SRolf Morel /// Replace a linalg.add with one operand the single user of a contraction,
36*94cf80d6SRolf Morel /// which has a zero-filled, "identity-mapped" destination and is dominated by
37*94cf80d6SRolf Morel /// the `other` operand, by the contraction with `other` as its dest.
38*94cf80d6SRolf Morel ///
39*94cf80d6SRolf Morel /// As an example, the following pseudo-code will be rewritten
40*94cf80d6SRolf Morel ///   %cst = arith.constant 0.000000e+00
41*94cf80d6SRolf Morel ///   %empty = tensor.empty()
42*94cf80d6SRolf Morel ///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
43*94cf80d6SRolf Morel ///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
44*94cf80d6SRolf Morel ///   %empty2 = tensor.empty()
45*94cf80d6SRolf Morel ///   %zeroed2 = linalg.fill ins(%cst : f32) outs(%empty2 : !type) -> !type
46*94cf80d6SRolf Morel ///   %F = linalg.matmul ins(%D, %E) outs(%zeroed2)
47*94cf80d6SRolf Morel ///   %out = linalg.add ins(%C, %F) outs(%empty)
48*94cf80d6SRolf Morel /// to:
49*94cf80d6SRolf Morel ///   %cst = arith.constant 0.000000e+00
50*94cf80d6SRolf Morel ///   %empty = tensor.empty()
51*94cf80d6SRolf Morel ///   %zeroed = linalg.fill ins(%cst : f32) outs(%empty : !type) -> !type
52*94cf80d6SRolf Morel ///   %C = linalg.matmul ins(%A, %B) outs(%zeroed)
53*94cf80d6SRolf Morel ///   %out = linalg.matmul ins(%D, %E) outs(%C)
54*94cf80d6SRolf Morel ///
55*94cf80d6SRolf Morel struct FoldAddIntoDest final : public OpRewritePattern<linalg::AddOp> {
56*94cf80d6SRolf Morel   using OpRewritePattern<linalg::AddOp>::OpRewritePattern;
57*94cf80d6SRolf Morel 
58*94cf80d6SRolf Morel   LogicalResult matchAndRewrite(linalg::AddOp addOp,
59*94cf80d6SRolf Morel                                 PatternRewriter &rewriter) const override {
60*94cf80d6SRolf Morel     // For now, pattern only applies on tensor types (memref support is TODO).
61*94cf80d6SRolf Morel     if (!addOp.hasPureTensorSemantics())
62*94cf80d6SRolf Morel       return failure();
63*94cf80d6SRolf Morel 
64*94cf80d6SRolf Morel     Value dominatingOperand = nullptr;
65*94cf80d6SRolf Morel     linalg::LinalgOp dominatedOp = nullptr;
66*94cf80d6SRolf Morel     { // We will forget about which operand was left or right after this block.
67*94cf80d6SRolf Morel       Value lhs = addOp.getInputs()[0];
68*94cf80d6SRolf Morel       Value rhs = addOp.getInputs()[1];
69*94cf80d6SRolf Morel 
70*94cf80d6SRolf Morel       // Can only put one of addOp's operands in the dest/out arg of the other's
71*94cf80d6SRolf Morel       // defining op based on suitable dominance.
72*94cf80d6SRolf Morel       // TODO: Can be generalized to move ops around as long as that still
73*94cf80d6SRolf Morel       //       respects use-def chains and doesn't affect side-effects.
74*94cf80d6SRolf Morel       if (auto rhsOp = rhs.getDefiningOp<linalg::LinalgOp>()) {
75*94cf80d6SRolf Morel         DominanceInfo domInfo(rhsOp);
76*94cf80d6SRolf Morel         if (domInfo.properlyDominates(lhs, rhsOp)) {
77*94cf80d6SRolf Morel           dominatingOperand = lhs;
78*94cf80d6SRolf Morel           dominatedOp = rhsOp;
79*94cf80d6SRolf Morel         }
80*94cf80d6SRolf Morel       }
81*94cf80d6SRolf Morel       if (auto lhsOp = lhs.getDefiningOp<linalg::LinalgOp>()) {
82*94cf80d6SRolf Morel         DominanceInfo domInfo(lhsOp);
83*94cf80d6SRolf Morel         if (domInfo.properlyDominates(rhs, lhsOp)) {
84*94cf80d6SRolf Morel           dominatingOperand = rhs;
85*94cf80d6SRolf Morel           dominatedOp = lhsOp;
86*94cf80d6SRolf Morel         }
87*94cf80d6SRolf Morel       }
88*94cf80d6SRolf Morel       if (!dominatingOperand || !dominatedOp)
89*94cf80d6SRolf Morel         return failure();
90*94cf80d6SRolf Morel       // NB: As linalg.add's generalisation ignores the out argument in its
91*94cf80d6SRolf Morel       //     region there is no need to perform checks on addOp's out argument.
92*94cf80d6SRolf Morel     }
93*94cf80d6SRolf Morel 
94*94cf80d6SRolf Morel     // When dominated op is a contraction we know it accumulates on its out arg.
95*94cf80d6SRolf Morel     // E.g., AddOp is not a contraction and hence ignores its out arg's value.
96*94cf80d6SRolf Morel     // TODO: Generalize check to also pass in case of other LinalgOps that
97*94cf80d6SRolf Morel     //       accumulate on their out arg but are not (binary) contraction ops.
98*94cf80d6SRolf Morel     auto dominatedDestOp =
99*94cf80d6SRolf Morel         dyn_cast<DestinationStyleOpInterface>((Operation *)dominatedOp);
100*94cf80d6SRolf Morel     if (dominatedOp->getNumResults() != 1 ||
101*94cf80d6SRolf Morel         !linalg::isaContractionOpInterface(dominatedOp) ||
102*94cf80d6SRolf Morel         (!dominatedDestOp || dominatedDestOp.getNumDpsInits() != 1))
103*94cf80d6SRolf Morel       return rewriter.notifyMatchFailure(
104*94cf80d6SRolf Morel           dominatedOp, "expected dominated op to be single-result "
105*94cf80d6SRolf Morel                        "destination-passing contraction");
106*94cf80d6SRolf Morel 
107*94cf80d6SRolf Morel     // To change the contraction's result, `addOp` must be its only user.
108*94cf80d6SRolf Morel     if (!dominatedOp->getResult(0).hasOneUse())
109*94cf80d6SRolf Morel       return rewriter.notifyMatchFailure(
110*94cf80d6SRolf Morel           dominatedOp,
111*94cf80d6SRolf Morel           "expected linalg.add to be single user of contraction's result");
112*94cf80d6SRolf Morel 
113*94cf80d6SRolf Morel     // As `dominatedOp` was already accumulating on its out argument, it is only
114*94cf80d6SRolf Morel     // safe to no longer use its current out arg when it is the additive ident.
115*94cf80d6SRolf Morel     auto *destOperand = dominatedDestOp.getDpsInitOperand(0);
116*94cf80d6SRolf Morel     if (!isDefinedAsZero(destOperand->get()))
117*94cf80d6SRolf Morel       return rewriter.notifyMatchFailure(
118*94cf80d6SRolf Morel           dominatedOp, "expected dominated op's dest to be additive zero");
119*94cf80d6SRolf Morel     // TODO: If the other op is a contraction and has additive ident as dest, we
120*94cf80d6SRolf Morel     // can swap the dests and achieve the proper sum, given suitable dominance.
121*94cf80d6SRolf Morel 
122*94cf80d6SRolf Morel     // As an operand to `addOp`, `dominatingOperand` has an identity affine_map.
123*94cf80d6SRolf Morel     // Hence, we can only substitute `dominatingOperand` for the dest of the
124*94cf80d6SRolf Morel     // contraction when dest's indexing_map corresponds to an identity map
125*94cf80d6SRolf Morel     // w.r.t. just the dimensions of dest, i.e. is an ordered projection.
126*94cf80d6SRolf Morel     SmallVector<AffineMap> indexMaps = dominatedOp.getIndexingMapsArray();
127*94cf80d6SRolf Morel     int prevDimPos = -1;
128*94cf80d6SRolf Morel     for (auto expr : indexMaps[destOperand->getOperandNumber()].getResults()) {
129*94cf80d6SRolf Morel       auto dim = dyn_cast<AffineDimExpr>(expr);
130*94cf80d6SRolf Morel       if (!dim || prevDimPos > static_cast<int>(dim.getPosition()))
131*94cf80d6SRolf Morel         return rewriter.notifyMatchFailure(
132*94cf80d6SRolf Morel             dominatedOp, "expected index_map for contraction's dest to be an "
133*94cf80d6SRolf Morel                          "ordered projection");
134*94cf80d6SRolf Morel       prevDimPos = dim.getPosition();
135*94cf80d6SRolf Morel     }
136*94cf80d6SRolf Morel 
137*94cf80d6SRolf Morel     // Replace the additive-ident, i.e. zero, out arg of the dominated op by the
138*94cf80d6SRolf Morel     // dominating summand. This makes the dominated op's result the sum of both
139*94cf80d6SRolf Morel     // of addOp's arguments - therefore we replace addOp and it uses by it.
140*94cf80d6SRolf Morel     rewriter.modifyOpInPlace(
141*94cf80d6SRolf Morel         dominatedOp, [&]() { dominatedOp->setOperand(2, dominatingOperand); });
142*94cf80d6SRolf Morel     rewriter.replaceAllOpUsesWith(addOp, dominatedOp->getResult(0));
143*94cf80d6SRolf Morel     return success();
144*94cf80d6SRolf Morel   }
145*94cf80d6SRolf Morel };
146*94cf80d6SRolf Morel 
147*94cf80d6SRolf Morel void linalg::populateFoldAddIntoDestPatterns(RewritePatternSet &patterns) {
148*94cf80d6SRolf Morel   // Replace linalg.add when destination passing suffices for achieving the sum.
149*94cf80d6SRolf Morel   patterns.add<FoldAddIntoDest>(patterns.getContext());
150*94cf80d6SRolf Morel }
151