1 //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 14 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" 15 #include "mlir/Dialect/Linalg/IR/Linalg.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 18 using namespace mlir; 19 using namespace mlir::bufferization; 20 using namespace mlir::linalg; 21 22 /// Get an output operand that matches the given input operand and can be used 23 /// to eliminate a tensor.empty op. 24 static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) { 25 for (OpOperand &operand : op.getDpsInitsMutable()) { 26 // Operand must be unused. 27 if (op.payloadUsesValueFromOperand(&operand)) 28 continue; 29 // Types must match. 30 if (operand.get().getType() != in->get().getType()) 31 continue; 32 // Indexing maps must match. 33 if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in)) 34 continue; 35 return &operand; 36 } 37 return nullptr; 38 } 39 40 LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep( 41 RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { 42 OpBuilder::InsertionGuard g(rewriter); 43 DominanceInfo domInfo; 44 45 op->walk([&](LinalgOp op) { 46 // Only ops with all "parallel" iterator types are supported. 47 if (op.getNumParallelLoops() != op.getNumLoops()) 48 return WalkResult::skip(); 49 50 for (OpOperand *in : op.getDpsInputOperands()) { 51 // Skip non-tensor operands. 52 if (!isa<RankedTensorType>(in->get().getType())) 53 continue; 54 55 // Find tensor.empty ops on the reverse SSA use-def chain. Only follow 56 // equivalent tensors. I.e., stop when there are ops such as extract_slice 57 // on the path. 58 TraversalConfig config; 59 config.followEquivalentOnly = true; 60 config.alwaysIncludeLeaves = false; 61 SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( 62 in, /*condition=*/ 63 [&](Value val) { 64 return val.getDefiningOp<tensor::EmptyOp>() && 65 val.getType() == in->get().getType(); 66 }, 67 config); 68 if (emptyTensors.empty()) 69 continue; 70 71 // Find matching out operand. 72 OpOperand *out = getUnusedOutOperand(op, in); 73 if (!out) 74 continue; 75 76 // Check if this transform would violate dominance. 77 if (!llvm::all_of(emptyTensors, [&](Value v) { 78 return domInfo.properlyDominates(out->get(), v.getDefiningOp()); 79 })) 80 continue; 81 82 // Replace all uses of the tensor.empty, but do not delete it yet. It will 83 // fold away later (to not invalidate DominanceInfo). 84 for (Value v : emptyTensors) { 85 assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty"); 86 rewriter.replaceAllUsesWith(v, out->get()); 87 } 88 89 // Turn the "in" into an "out". 90 rewriter.modifyOpInPlace(op, [&]() { 91 out->set(in->get()); 92 // The original "in" could be removed entirely here (because it will no 93 // longer have any uses in the payload), but we delegate this to 94 // existing cleanup patterns that remove unused operands. 95 in->set(emptyTensors.front()); 96 BlockArgument outArg = op.getMatchingBlockArgument(out); 97 assert(outArg.getUses().empty() && "expected that out has no uses"); 98 BlockArgument inArg = op.getMatchingBlockArgument(in); 99 rewriter.replaceAllUsesWith(inArg, outArg); 100 assert(!op.payloadUsesValueFromOperand(in) && 101 "expected that the in operand is now unused"); 102 }); 103 104 state.resetCache(); 105 } 106 107 return WalkResult::advance(); 108 }); 109 return success(); 110 } 111