xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp (revision d9111f19d2ea53d8ce105b3d09425394ccf37969)
//===- EliminateEmptyTensors.cpp - tensor.empty op elimination ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 
18 using namespace mlir;
19 using namespace mlir::bufferization;
20 using namespace mlir::linalg;
21 
22 /// Get an output operand that matches the given input operand and can be used
23 /// to eliminate a tensor.empty op.
24 static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
25   for (OpOperand &operand : op.getDpsInitsMutable()) {
26     // Operand must be unused.
27     if (op.payloadUsesValueFromOperand(&operand))
28       continue;
29     // Types must match.
30     if (operand.get().getType() != in->get().getType())
31       continue;
32     // Indexing maps must match.
33     if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in))
34       continue;
35     return &operand;
36   }
37   return nullptr;
38 }
39 
40 LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
41     RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
42   OpBuilder::InsertionGuard g(rewriter);
43   DominanceInfo domInfo;
44 
45   op->walk([&](LinalgOp op) {
46     // Only ops with all "parallel" iterator types are supported.
47     if (op.getNumParallelLoops() != op.getNumLoops())
48       return WalkResult::skip();
49 
50     for (OpOperand *in : op.getDpsInputOperands()) {
51       // Skip non-tensor operands.
52       if (!isa<RankedTensorType>(in->get().getType()))
53         continue;
54 
55       // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
56       // equivalent tensors. I.e., stop when there are ops such as extract_slice
57       // on the path.
58       TraversalConfig config;
59       config.followEquivalentOnly = true;
60       config.alwaysIncludeLeaves = false;
61       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
62           in, /*condition=*/
63           [&](Value val) {
64             return val.getDefiningOp<tensor::EmptyOp>() &&
65                    val.getType() == in->get().getType();
66           },
67           config);
68       if (emptyTensors.empty())
69         continue;
70 
71       // Find matching out operand.
72       OpOperand *out = getUnusedOutOperand(op, in);
73       if (!out)
74         continue;
75 
76       // Check if this transform would violate dominance.
77       if (!llvm::all_of(emptyTensors, [&](Value v) {
78             return domInfo.properlyDominates(out->get(), v.getDefiningOp());
79           }))
80         continue;
81 
82       // Replace all uses of the tensor.empty, but do not delete it yet. It will
83       // fold away later (to not invalidate DominanceInfo).
84       for (Value v : emptyTensors) {
85         assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
86         rewriter.replaceAllUsesWith(v, out->get());
87       }
88 
89       // Turn the "in" into an "out".
90       rewriter.modifyOpInPlace(op, [&]() {
91         out->set(in->get());
92         // The original "in" could be removed entirely here (because it will no
93         // longer have any uses in the payload), but we delegate this to
94         // existing cleanup patterns that remove unused operands.
95         in->set(emptyTensors.front());
96         BlockArgument outArg = op.getMatchingBlockArgument(out);
97         assert(outArg.getUses().empty() && "expected that out has no uses");
98         BlockArgument inArg = op.getMatchingBlockArgument(in);
99         rewriter.replaceAllUsesWith(inArg, outArg);
100         assert(!op.payloadUsesValueFromOperand(in) &&
101                "expected that the in operand is now unused");
102       });
103 
104       state.resetCache();
105     }
106 
107     return WalkResult::advance();
108   });
109   return success();
110 }
111