xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp (revision d9111f19d2ea53d8ce105b3d09425394ccf37969)
1180f9ef8SMatthias Springer //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2180f9ef8SMatthias Springer //
3180f9ef8SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4180f9ef8SMatthias Springer // See https://llvm.org/LICENSE.txt for license information.
5180f9ef8SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6180f9ef8SMatthias Springer //
7180f9ef8SMatthias Springer //===----------------------------------------------------------------------===//
8180f9ef8SMatthias Springer 
9180f9ef8SMatthias Springer #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10180f9ef8SMatthias Springer 
11180f9ef8SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12180f9ef8SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13180f9ef8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14180f9ef8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
15180f9ef8SMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h"
16180f9ef8SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h"
17180f9ef8SMatthias Springer 
18180f9ef8SMatthias Springer using namespace mlir;
19180f9ef8SMatthias Springer using namespace mlir::bufferization;
20180f9ef8SMatthias Springer using namespace mlir::linalg;
21180f9ef8SMatthias Springer 
22180f9ef8SMatthias Springer /// Get an output operand that matches the given input operand and can be used
23180f9ef8SMatthias Springer /// to eliminate a tensor.empty op.
24180f9ef8SMatthias Springer static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
250b2197b0SMatthias Springer   for (OpOperand &operand : op.getDpsInitsMutable()) {
26180f9ef8SMatthias Springer     // Operand must be unused.
270b2197b0SMatthias Springer     if (op.payloadUsesValueFromOperand(&operand))
28180f9ef8SMatthias Springer       continue;
29180f9ef8SMatthias Springer     // Types must match.
300b2197b0SMatthias Springer     if (operand.get().getType() != in->get().getType())
31180f9ef8SMatthias Springer       continue;
32180f9ef8SMatthias Springer     // Indexing maps must match.
330b2197b0SMatthias Springer     if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in))
34180f9ef8SMatthias Springer       continue;
350b2197b0SMatthias Springer     return &operand;
36180f9ef8SMatthias Springer   }
37180f9ef8SMatthias Springer   return nullptr;
38180f9ef8SMatthias Springer }
39180f9ef8SMatthias Springer 
40180f9ef8SMatthias Springer LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
41180f9ef8SMatthias Springer     RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
42180f9ef8SMatthias Springer   OpBuilder::InsertionGuard g(rewriter);
43180f9ef8SMatthias Springer   DominanceInfo domInfo;
44180f9ef8SMatthias Springer 
45180f9ef8SMatthias Springer   op->walk([&](LinalgOp op) {
46180f9ef8SMatthias Springer     // Only ops with all "parallel" iterator types are supported.
47180f9ef8SMatthias Springer     if (op.getNumParallelLoops() != op.getNumLoops())
48180f9ef8SMatthias Springer       return WalkResult::skip();
49180f9ef8SMatthias Springer 
50180f9ef8SMatthias Springer     for (OpOperand *in : op.getDpsInputOperands()) {
51180f9ef8SMatthias Springer       // Skip non-tensor operands.
52a5757c5bSChristian Sigg       if (!isa<RankedTensorType>(in->get().getType()))
53180f9ef8SMatthias Springer         continue;
54180f9ef8SMatthias Springer 
55180f9ef8SMatthias Springer       // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
56180f9ef8SMatthias Springer       // equivalent tensors. I.e., stop when there are ops such as extract_slice
57180f9ef8SMatthias Springer       // on the path.
58180f9ef8SMatthias Springer       TraversalConfig config;
59180f9ef8SMatthias Springer       config.followEquivalentOnly = true;
60180f9ef8SMatthias Springer       config.alwaysIncludeLeaves = false;
61180f9ef8SMatthias Springer       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
62*d9111f19SAmir Bishara           in, /*condition=*/
636b65d79fSSpenser Bauman           [&](Value val) {
646b65d79fSSpenser Bauman             return val.getDefiningOp<tensor::EmptyOp>() &&
656b65d79fSSpenser Bauman                    val.getType() == in->get().getType();
666b65d79fSSpenser Bauman           },
67180f9ef8SMatthias Springer           config);
68180f9ef8SMatthias Springer       if (emptyTensors.empty())
69180f9ef8SMatthias Springer         continue;
70180f9ef8SMatthias Springer 
71180f9ef8SMatthias Springer       // Find matching out operand.
72180f9ef8SMatthias Springer       OpOperand *out = getUnusedOutOperand(op, in);
73180f9ef8SMatthias Springer       if (!out)
74180f9ef8SMatthias Springer         continue;
75180f9ef8SMatthias Springer 
76180f9ef8SMatthias Springer       // Check if this transform would violate dominance.
77180f9ef8SMatthias Springer       if (!llvm::all_of(emptyTensors, [&](Value v) {
78180f9ef8SMatthias Springer             return domInfo.properlyDominates(out->get(), v.getDefiningOp());
79180f9ef8SMatthias Springer           }))
80180f9ef8SMatthias Springer         continue;
81180f9ef8SMatthias Springer 
82180f9ef8SMatthias Springer       // Replace all uses of the tensor.empty, but do not delete it yet. It will
83180f9ef8SMatthias Springer       // fold away later (to not invalidate DominanceInfo).
84180f9ef8SMatthias Springer       for (Value v : emptyTensors) {
85180f9ef8SMatthias Springer         assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
86180f9ef8SMatthias Springer         rewriter.replaceAllUsesWith(v, out->get());
87180f9ef8SMatthias Springer       }
88180f9ef8SMatthias Springer 
89180f9ef8SMatthias Springer       // Turn the "in" into an "out".
905fcf907bSMatthias Springer       rewriter.modifyOpInPlace(op, [&]() {
91180f9ef8SMatthias Springer         out->set(in->get());
92180f9ef8SMatthias Springer         // The original "in" could be removed entirely here (because it will no
93180f9ef8SMatthias Springer         // longer have any uses in the payload), but we delegate this to
94180f9ef8SMatthias Springer         // existing cleanup patterns that remove unused operands.
95180f9ef8SMatthias Springer         in->set(emptyTensors.front());
96180f9ef8SMatthias Springer         BlockArgument outArg = op.getMatchingBlockArgument(out);
97180f9ef8SMatthias Springer         assert(outArg.getUses().empty() && "expected that out has no uses");
98180f9ef8SMatthias Springer         BlockArgument inArg = op.getMatchingBlockArgument(in);
99180f9ef8SMatthias Springer         rewriter.replaceAllUsesWith(inArg, outArg);
100180f9ef8SMatthias Springer         assert(!op.payloadUsesValueFromOperand(in) &&
101180f9ef8SMatthias Springer                "expected that the in operand is now unused");
102180f9ef8SMatthias Springer       });
103180f9ef8SMatthias Springer 
104180f9ef8SMatthias Springer       state.resetCache();
105180f9ef8SMatthias Springer     }
106180f9ef8SMatthias Springer 
107180f9ef8SMatthias Springer     return WalkResult::advance();
108180f9ef8SMatthias Springer   });
109180f9ef8SMatthias Springer   return success();
110180f9ef8SMatthias Springer }
111