1180f9ef8SMatthias Springer //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===// 2180f9ef8SMatthias Springer // 3180f9ef8SMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4180f9ef8SMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5180f9ef8SMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6180f9ef8SMatthias Springer // 7180f9ef8SMatthias Springer //===----------------------------------------------------------------------===// 8180f9ef8SMatthias Springer 9180f9ef8SMatthias Springer #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 10180f9ef8SMatthias Springer 11180f9ef8SMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12180f9ef8SMatthias Springer #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13180f9ef8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 14180f9ef8SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 15180f9ef8SMatthias Springer #include "mlir/Dialect/Linalg/IR/Linalg.h" 16180f9ef8SMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 17180f9ef8SMatthias Springer 18180f9ef8SMatthias Springer using namespace mlir; 19180f9ef8SMatthias Springer using namespace mlir::bufferization; 20180f9ef8SMatthias Springer using namespace mlir::linalg; 21180f9ef8SMatthias Springer 22180f9ef8SMatthias Springer /// Get an output operand that matches the given input operand and can be used 23180f9ef8SMatthias Springer /// to eliminate a tensor.empty op. 24180f9ef8SMatthias Springer static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) { 250b2197b0SMatthias Springer for (OpOperand &operand : op.getDpsInitsMutable()) { 26180f9ef8SMatthias Springer // Operand must be unused. 270b2197b0SMatthias Springer if (op.payloadUsesValueFromOperand(&operand)) 28180f9ef8SMatthias Springer continue; 29180f9ef8SMatthias Springer // Types must match. 300b2197b0SMatthias Springer if (operand.get().getType() != in->get().getType()) 31180f9ef8SMatthias Springer continue; 32180f9ef8SMatthias Springer // Indexing maps must match. 330b2197b0SMatthias Springer if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in)) 34180f9ef8SMatthias Springer continue; 350b2197b0SMatthias Springer return &operand; 36180f9ef8SMatthias Springer } 37180f9ef8SMatthias Springer return nullptr; 38180f9ef8SMatthias Springer } 39180f9ef8SMatthias Springer 40180f9ef8SMatthias Springer LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep( 41180f9ef8SMatthias Springer RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { 42180f9ef8SMatthias Springer OpBuilder::InsertionGuard g(rewriter); 43180f9ef8SMatthias Springer DominanceInfo domInfo; 44180f9ef8SMatthias Springer 45180f9ef8SMatthias Springer op->walk([&](LinalgOp op) { 46180f9ef8SMatthias Springer // Only ops with all "parallel" iterator types are supported. 47180f9ef8SMatthias Springer if (op.getNumParallelLoops() != op.getNumLoops()) 48180f9ef8SMatthias Springer return WalkResult::skip(); 49180f9ef8SMatthias Springer 50180f9ef8SMatthias Springer for (OpOperand *in : op.getDpsInputOperands()) { 51180f9ef8SMatthias Springer // Skip non-tensor operands. 52a5757c5bSChristian Sigg if (!isa<RankedTensorType>(in->get().getType())) 53180f9ef8SMatthias Springer continue; 54180f9ef8SMatthias Springer 55180f9ef8SMatthias Springer // Find tensor.empty ops on the reverse SSA use-def chain. Only follow 56180f9ef8SMatthias Springer // equivalent tensors. I.e., stop when there are ops such as extract_slice 57180f9ef8SMatthias Springer // on the path. 58180f9ef8SMatthias Springer TraversalConfig config; 59180f9ef8SMatthias Springer config.followEquivalentOnly = true; 60180f9ef8SMatthias Springer config.alwaysIncludeLeaves = false; 61180f9ef8SMatthias Springer SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( 62*d9111f19SAmir Bishara in, /*condition=*/ 636b65d79fSSpenser Bauman [&](Value val) { 646b65d79fSSpenser Bauman return val.getDefiningOp<tensor::EmptyOp>() && 656b65d79fSSpenser Bauman val.getType() == in->get().getType(); 666b65d79fSSpenser Bauman }, 67180f9ef8SMatthias Springer config); 68180f9ef8SMatthias Springer if (emptyTensors.empty()) 69180f9ef8SMatthias Springer continue; 70180f9ef8SMatthias Springer 71180f9ef8SMatthias Springer // Find matching out operand. 72180f9ef8SMatthias Springer OpOperand *out = getUnusedOutOperand(op, in); 73180f9ef8SMatthias Springer if (!out) 74180f9ef8SMatthias Springer continue; 75180f9ef8SMatthias Springer 76180f9ef8SMatthias Springer // Check if this transform would violate dominance. 77180f9ef8SMatthias Springer if (!llvm::all_of(emptyTensors, [&](Value v) { 78180f9ef8SMatthias Springer return domInfo.properlyDominates(out->get(), v.getDefiningOp()); 79180f9ef8SMatthias Springer })) 80180f9ef8SMatthias Springer continue; 81180f9ef8SMatthias Springer 82180f9ef8SMatthias Springer // Replace all uses of the tensor.empty, but do not delete it yet. It will 83180f9ef8SMatthias Springer // fold away later (to not invalidate DominanceInfo). 84180f9ef8SMatthias Springer for (Value v : emptyTensors) { 85180f9ef8SMatthias Springer assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty"); 86180f9ef8SMatthias Springer rewriter.replaceAllUsesWith(v, out->get()); 87180f9ef8SMatthias Springer } 88180f9ef8SMatthias Springer 89180f9ef8SMatthias Springer // Turn the "in" into an "out". 905fcf907bSMatthias Springer rewriter.modifyOpInPlace(op, [&]() { 91180f9ef8SMatthias Springer out->set(in->get()); 92180f9ef8SMatthias Springer // The original "in" could be removed entirely here (because it will no 93180f9ef8SMatthias Springer // longer have any uses in the payload), but we delegate this to 94180f9ef8SMatthias Springer // existing cleanup patterns that remove unused operands. 95180f9ef8SMatthias Springer in->set(emptyTensors.front()); 96180f9ef8SMatthias Springer BlockArgument outArg = op.getMatchingBlockArgument(out); 97180f9ef8SMatthias Springer assert(outArg.getUses().empty() && "expected that out has no uses"); 98180f9ef8SMatthias Springer BlockArgument inArg = op.getMatchingBlockArgument(in); 99180f9ef8SMatthias Springer rewriter.replaceAllUsesWith(inArg, outArg); 100180f9ef8SMatthias Springer assert(!op.payloadUsesValueFromOperand(in) && 101180f9ef8SMatthias Springer "expected that the in operand is now unused"); 102180f9ef8SMatthias Springer }); 103180f9ef8SMatthias Springer 104180f9ef8SMatthias Springer state.resetCache(); 105180f9ef8SMatthias Springer } 106180f9ef8SMatthias Springer 107180f9ef8SMatthias Springer return WalkResult::advance(); 108180f9ef8SMatthias Springer }); 109180f9ef8SMatthias Springer return success(); 110180f9ef8SMatthias Springer } 111