//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Bufferization/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace bufferization {
#define GEN_PASS_DEF_EMPTYTENSORELIMINATION
#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
} // namespace bufferization
} // namespace mlir

using namespace mlir;
using namespace mlir::bufferization;

/// Return true if all `neededValues` are in scope at the given
/// `insertionPoint`.
static bool
neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                   Operation *insertionPoint,
                                   const SmallVector<Value> &neededValues) {
  for (Value val : neededValues) {
    if (auto bbArg = dyn_cast<BlockArgument>(val)) {
      Block *owner = bbArg.getOwner();
      // A block argument is in scope iff the insertion point is (possibly
      // transitively) nested inside the argument's owner block.
      if (!owner->findAncestorOpInBlock(*insertionPoint))
        return false;
    } else {
      // An OpResult is in scope iff its defining op properly dominates the
      // insertion point.
      auto opResult = cast<OpResult>(val);
      if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
        return false;
    }
  }
  // All needed values are visible at the insertion point.
  return true;
}

/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
/// use of `user` operation, assuming that the replacement may use any
/// value from `neededValues`. Returns nullptr if no candidate insertion
/// point satisfies both scoping and ordering constraints.
static Operation *
findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
                        const SmallVector<Value> &neededValues) {
  DominanceInfo domInfo;
  Operation *candidateInsertionPoint = emptyTensorOp;

  // Gather all possible insertion points: the location of
  // `candidateInsertionPoint` and right after the definition of each value in
  // `neededValues`.
  SmallVector<Operation *> insertionPointCandidates;
  insertionPointCandidates.push_back(candidateInsertionPoint);
  for (Value val : neededValues) {
    // Note: The anchor op is using all of `neededValues`, so:
    // * in case of a block argument: There must be at least one op in the block
    //                                (the anchor op or one of its parents).
    // * in case of an OpResult: There must be at least one op right after the
    //                           defining op (the anchor op or one of its
    //                           parents).
    if (auto bbArg = dyn_cast<BlockArgument>(val)) {
      // Earliest point where the block argument is visible: the first op of
      // its owner block.
      insertionPointCandidates.push_back(
          &bbArg.getOwner()->getOperations().front());
    } else {
      // Earliest point where the OpResult is visible: right after its
      // defining op.
      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
    }
  }

  // Select first matching insertion point.
  for (Operation *insertionPoint : insertionPointCandidates) {
    // Check if all needed values are in scope.
    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                            neededValues))
      continue;
    // Check if the insertion point is before the use to be replaced.
    if (!domInfo.dominates(insertionPoint, user))
      continue;
    return insertionPoint;
  }

  // No suitable insertion point was found.
  return nullptr;
}

/// Build a subset extraction that can replace `emptyTensorOp`'s use in
/// `user`, materialized at a valid insertion point. Returns an empty Value
/// (and builds nothing) if no valid insertion point exists.
Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
                                                 SubsetInsertionOpInterface op,
                                                 tensor::EmptyOp emptyTensorOp,
                                                 Operation *user) {
  // Restore the rewriter's insertion point on return.
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  // All values that are needed to create the replacement op.
  SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
  // Find a suitable insertion point. If no suitable insertion point
  // for the replacement can be found, return an empty value to skip
  // this replacement.
  Operation *insertionPoint =
      findValidInsertionPoint(emptyTensorOp, user, neededValues);
  if (!insertionPoint)
    return {};

  rewriter.setInsertionPoint(insertionPoint);
  Value replacement =
      op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
  return replacement;
}

/// Try to eliminate `tensor.empty` ops inside `op` by replacing each one that
/// feeds (via an equivalent, same-typed use-def chain) into the source of a
/// SubsetInsertionOpInterface op with a subset extraction built by
/// `subsetsExtractionFn`.
LogicalResult mlir::bufferization::eliminateEmptyTensors(
    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
    ControlBuildSubsetExtractionFn subsetsExtractionFn) {
  OpBuilder::InsertionGuard g(rewriter);
  // OpOperands visited during the reverse use-def chain traversal; used below
  // to locate the exact use of each found tensor.empty.
  llvm::DenseSet<OpOperand *> visitedOpOperands;
  op->walk([&](SubsetInsertionOpInterface op) {
    // Reset per-anchor traversal state.
    visitedOpOperands.clear();
    OpOperand &source = op.getSourceOperand();
    // Skip operands that do not bufferize inplace. "tensor.empty" could still
    // be replaced, but the transformation may not be beneficial.
    if (!state.isInPlace(source))
      return WalkResult::skip();

    // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
    // equivalent tensors. I.e., stop when there are ops such as extract_slice
    // on the path.
    TraversalConfig config;
    config.followEquivalentOnly = true;
    config.alwaysIncludeLeaves = false;
    // Replace only if the types match or are static <-> dynamic casts. We do
    // not support slices or reshapes.
    // TODO: This could be extended to support IR such as:
    // %0 = tensor.empty() : tensor<128xf32>
    // %1 = "some_op"(%0) : (tensor<128xf32>) -> (tensor<128xf32>)
    // %2 = tensor.expand_shape %1 ...
    // %3 = tensor.insert_slice %2 into ...
    config.followSameTypeOrCastsOnly = true;
    SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
        &source, /*condition=*/
        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
        &visitedOpOperands);

    for (Value v : emptyTensors) {
      auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
      assert(emptyTensorOp && "expected tensor.empty op");
      // Find the use to be replaced from the use-def chain: the visited
      // operand that is a use of this tensor.empty.
      auto iter = llvm::find_if(
          visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
            return llvm::count(emptyTensorOp->getUses(), *opOperand);
          });
      // This used to be a "blocking use", so it must be on the chain.
      assert(iter != visitedOpOperands.end() && "could not find use");
      OpOperand *useToBeReplaced = *iter;
      Operation *user = useToBeReplaced->getOwner();
      auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
      if (!replacement)
        continue;
      // Guard against a replacement that is the empty tensor itself.
      if (emptyTensorOp == replacement.getDefiningOp())
        continue;
      if (replacement.getType() != v.getType()) {
        // Only static <-> dynamic casts are supported: element types must
        // already match.
        if (cast<ShapedType>(replacement.getType()).getElementType() !=
            cast<ShapedType>(v.getType()).getElementType())
          continue;
        rewriter.setInsertionPointAfterValue(replacement);
        replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                      replacement);
      }
      // Replace the specific use of the tensor::EmptyOp.
      rewriter.modifyOpInPlace(user, [&]() {
        user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
      });
      // The IR changed; invalidate cached analysis results.
      state.resetCache();
    }

    return WalkResult::advance();
  });

  return success();
}

namespace {
/// Pass that wraps `eliminateEmptyTensors` behind the auto-generated
/// EmptyTensorElimination pass base class.
struct EmptyTensorElimination
    : public bufferization::impl::EmptyTensorEliminationBase<
          EmptyTensorElimination> {
  EmptyTensorElimination() = default;

  void runOnOperation() override;

  void getDependentDialects(DialectRegistry &registry) const override {
    // The transformation may create bufferization and tensor (cast) ops.
    registry
        .insert<bufferization::BufferizationDialect, tensor::TensorDialect>();
  }
};
} // namespace

/// Convenience overload: runs One-Shot analysis on `op` (module analysis with
/// function boundaries if `op` is a ModuleOp) and then eliminates empty
/// tensors based on that analysis.
LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
                                                         Operation *op) {
  auto moduleOp = dyn_cast<ModuleOp>(op);
  OneShotBufferizationOptions options;
  options.allowReturnAllocsFromLoops = true;
  if (moduleOp)
    options.bufferizeFunctionBoundaries = true;
  OneShotAnalysisState state(op, options);
  if (moduleOp) {
    // Module analysis takes into account function boundaries.
    if (failed(analyzeModuleOp(moduleOp, state)))
      return failure();
  } else {
    // Regular One-Shot Bufferize ignores func.func block arguments, func.call,
    // func.return.
    if (failed(analyzeOp(op, state)))
      return failure();
  }

  return bufferization::eliminateEmptyTensors(rewriter, op, state);
}

void EmptyTensorElimination::runOnOperation() {
  IRRewriter rewriter(getOperation()->getContext());
  if (failed(bufferization::eliminateEmptyTensors(rewriter, getOperation())))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::bufferization::createEmptyTensorEliminationPass() {
  return std::make_unique<EmptyTensorElimination>();
}