// xref: /llvm-project/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp (revision d9111f19d2ea53d8ce105b3d09425394ccf37969)
1 //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
15 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/Dominance.h"
18 #include "mlir/Interfaces/SubsetOpInterface.h"
19 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 namespace bufferization {
23 #define GEN_PASS_DEF_EMPTYTENSORELIMINATION
24 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
25 } // namespace bufferization
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::bufferization;
30 
31 /// Return true if all `neededValues` are in scope at the given
32 /// `insertionPoint`.
33 static bool
34 neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
35                                    Operation *insertionPoint,
36                                    const SmallVector<Value> &neededValues) {
37   for (Value val : neededValues) {
38     if (auto bbArg = dyn_cast<BlockArgument>(val)) {
39       Block *owner = bbArg.getOwner();
40       if (!owner->findAncestorOpInBlock(*insertionPoint))
41         return false;
42     } else {
43       auto opResult = cast<OpResult>(val);
44       if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
45         return false;
46     }
47   }
48   return true;
49 }
50 
51 /// Find a valid insertion point for a replacement of `emptyTensorOp`'s
52 /// use of `user` operation, assuming that the replacement may use any
53 /// value from `neededValues`.
54 static Operation *
55 findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
56                         const SmallVector<Value> &neededValues) {
57   DominanceInfo domInfo;
58   Operation *candidateInsertionPoint = emptyTensorOp;
59 
60   // Gather all possible insertion points: the location of
61   // `candidateInsertionPoint` and right after the definition of each value in
62   // `neededValues`.
63   SmallVector<Operation *> insertionPointCandidates;
64   insertionPointCandidates.push_back(candidateInsertionPoint);
65   for (Value val : neededValues) {
66     // Note: The anchor op is using all of `neededValues`, so:
67     // * in case of a block argument: There must be at least one op in the block
68     //                                (the anchor op or one of its parents).
69     // * in case of an OpResult: There must be at least one op right after the
70     //                           defining op (the anchor op or one of its
71     //                           parents).
72     if (auto bbArg = dyn_cast<BlockArgument>(val)) {
73       insertionPointCandidates.push_back(
74           &bbArg.getOwner()->getOperations().front());
75     } else {
76       insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
77     }
78   }
79 
80   // Select first matching insertion point.
81   for (Operation *insertionPoint : insertionPointCandidates) {
82     // Check if all needed values are in scope.
83     if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
84                                             neededValues))
85       continue;
86     // Check if the insertion point is before the use to be replaced.
87     if (!domInfo.dominates(insertionPoint, user))
88       continue;
89     return insertionPoint;
90   }
91 
92   // No suitable insertion point was found.
93   return nullptr;
94 }
95 
96 Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
97                                                  SubsetInsertionOpInterface op,
98                                                  tensor::EmptyOp emptyTensorOp,
99                                                  Operation *user) {
100 
101   mlir::OpBuilder::InsertionGuard guard(rewriter);
102   // All values that are needed to create the replacement op.
103   SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
104   // Find a suitable insertion point. If no suitable insertion point
105   // for the replacement can be found, return an empty value to skip
106   // this replacement.
107   Operation *insertionPoint =
108       findValidInsertionPoint(emptyTensorOp, user, neededValues);
109   if (!insertionPoint)
110     return {};
111 
112   rewriter.setInsertionPoint(insertionPoint);
113   Value replacement =
114       op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
115   return replacement;
116 }
117 
/// Try to eliminate "tensor.empty" ops inside `op`: for each
/// SubsetInsertionOpInterface op whose source operand bufferizes in-place,
/// walk the reverse use-def chain to tensor.empty ops and replace their
/// relevant use with a subset extraction built by `subsetsExtractionFn`.
LogicalResult mlir::bufferization::eliminateEmptyTensors(
    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
    ControlBuildSubsetExtractionFn subsetsExtractionFn) {
  OpBuilder::InsertionGuard g(rewriter);
  // Operands visited during the reverse use-def traversal; used afterwards to
  // recover the exact use of each tensor.empty op that must be replaced.
  llvm::DenseSet<OpOperand *> visitedOpOperands;
  op->walk([&](SubsetInsertionOpInterface op) {
    // Reset per-insertion-op traversal state.
    visitedOpOperands.clear();
    OpOperand &source = op.getSourceOperand();
    // Skip operands that do not bufferize inplace. "tensor.empty" could still
    // be replaced, but the transformation may not be beneficial.
    if (!state.isInPlace(source))
      return WalkResult::skip();

    // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
    // equivalent tensors. I.e., stop when there are ops such as extract_slice
    // on the path.
    TraversalConfig config;
    config.followEquivalentOnly = true;
    config.alwaysIncludeLeaves = false;
    // Replace only if the types match or are static <-> dynamic casts. We do
    // not support slices or reshapes.
    // TODO: This could be extended to support IR such as:
    // %0 = tensor.empty() : tensor<128xf32>
    // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
    // %2 = tensor.expand_shape %1 ...
    // %3 = tensor.insert_slice %2 into ...
    config.followSameTypeOrCastsOnly = true;
    SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
        &source, /*condition=*/
        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
        &visitedOpOperands);

    for (Value v : emptyTensors) {
      auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
      assert(emptyTensorOp && "expected tensor.empty op");
      // Find the use to be replaced from the use-def chain: the visited
      // operand that is also a use of this tensor.empty op.
      auto iter = llvm::find_if(
          visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
            return llvm::count(emptyTensorOp->getUses(), *opOperand);
          });

      assert(iter != visitedOpOperands.end() && "could not find use");
      OpOperand *useToBeReplaced = *iter;
      Operation *user = useToBeReplaced->getOwner();
      // The extraction function may decline by returning a null value, or
      // return the tensor.empty op itself; skip the replacement in both
      // cases.
      auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
      if (!replacement)
        continue;
      if (emptyTensorOp == replacement.getDefiningOp())
        continue;
      if (replacement.getType() != v.getType()) {
        // Only same-element-type mismatches (static <-> dynamic shapes) can
        // be bridged with a tensor.cast; otherwise skip.
        if (cast<ShapedType>(replacement.getType()).getElementType() !=
            cast<ShapedType>(v.getType()).getElementType())
          continue;
        rewriter.setInsertionPointAfterValue(replacement);
        replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                      replacement);
      }
      // Replace the specific use of the tensor::EmptyOp.
      rewriter.modifyOpInPlace(user, [&]() {
        user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
      });
      // The IR was modified; drop cached analysis results.
      state.resetCache();
    }

    return WalkResult::advance();
  });

  return success();
}
187 
namespace {
/// Pass wrapper around `bufferization::eliminateEmptyTensors` that runs the
/// transformation on the pass's root operation.
struct EmptyTensorElimination
    : public bufferization::impl::EmptyTensorEliminationBase<
          EmptyTensorElimination> {
  EmptyTensorElimination() = default;

  void runOnOperation() override;

  /// Declare the dialects whose ops this pass may create (e.g. tensor.cast).
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
  }
};
} // namespace
202 
203 LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
204                                                          Operation *op) {
205   auto moduleOp = dyn_cast<ModuleOp>(op);
206   OneShotBufferizationOptions options;
207   options.allowReturnAllocsFromLoops = true;
208   if (moduleOp)
209     options.bufferizeFunctionBoundaries = true;
210   OneShotAnalysisState state(op, options);
211   if (moduleOp) {
212     // Module analysis takes into account function boundaries.
213     if (failed(analyzeModuleOp(moduleOp, state)))
214       return failure();
215   } else {
216     // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
217     // func.return.
218     if (failed(analyzeOp(op, state)))
219       return failure();
220   }
221 
222   return bufferization::eliminateEmptyTensors(rewriter, op, state);
223 }
224 
225 void EmptyTensorElimination::runOnOperation() {
226   IRRewriter rewriter(getOperation()->getContext());
227   if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
228     signalPassFailure();
229 }
230 
/// Create a new instance of the empty tensor elimination pass.
std::unique_ptr<Pass> mlir::bufferization::createEmptyTensorEliminationPass() {
  return std::make_unique<EmptyTensorElimination>();
}
234