100f239e4SArteen Abrishami //===- TosaReduceTransposes.cpp -------------------------------------------===// 200f239e4SArteen Abrishami // 300f239e4SArteen Abrishami // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 400f239e4SArteen Abrishami // See https://llvm.org/LICENSE.txt for license information. 500f239e4SArteen Abrishami // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 600f239e4SArteen Abrishami // 700f239e4SArteen Abrishami //===----------------------------------------------------------------------===// 800f239e4SArteen Abrishami 900f239e4SArteen Abrishami // ---------- 1000f239e4SArteen Abrishami // Motivation: 1100f239e4SArteen Abrishami // ---------- 1200f239e4SArteen Abrishami 1300f239e4SArteen Abrishami // Some legalization pathways introduce redundant tosa.TRANSPOSE 1400f239e4SArteen Abrishami // operations that result in avoidable data movement. For example, 1500f239e4SArteen Abrishami // PyTorch -> TOSA contains a lot of unnecessary transposes due 1600f239e4SArteen Abrishami // to conversions between NCHW and NHWC. 1700f239e4SArteen Abrishami 1800f239e4SArteen Abrishami // We wish to remove all the ones that we can, since in general 1900f239e4SArteen Abrishami // it is possible to remove the overwhelming majority. 2000f239e4SArteen Abrishami 2100f239e4SArteen Abrishami // ------------------- 2200f239e4SArteen Abrishami // High-Level Overview: 2300f239e4SArteen Abrishami // ------------------- 2400f239e4SArteen Abrishami 2500f239e4SArteen Abrishami // The pass works through the transpose operators in the program. It begins at 2600f239e4SArteen Abrishami // some transpose operator with an associated permutations tensor. 
It traverses 2700f239e4SArteen Abrishami // upwards through the dependencies of this transpose and verifies that we 2800f239e4SArteen Abrishami // encounter only operators with the TosaElementwiseOperator trait and terminate 2900f239e4SArteen Abrishami // in either constants, reshapes, or transposes. 3000f239e4SArteen Abrishami 3100f239e4SArteen Abrishami // We then evaluate whether there are any additional restrictions (the 3200f239e4SArteen Abrishami // transposes it terminates in must invert the one we began at, and the reshapes 3300f239e4SArteen Abrishami // must be ones into which we can fold the transpose), and then we hoist the 3400f239e4SArteen Abrishami // transpose through the intervening operators, folding it at the constants, 3500f239e4SArteen Abrishami // reshapes, and transposes. 3600f239e4SArteen Abrishami 3700f239e4SArteen Abrishami // Finally, we ensure that we do not need both the transposed form (the form 3800f239e4SArteen Abrishami // that had the transpose hoisted through it) and the untransposed form (what 3900f239e4SArteen Abrishami // it was before), by analyzing the usages of those dependent operators of a 4000f239e4SArteen Abrishami // given transpose we are attempting to hoist and replace. 4100f239e4SArteen Abrishami 4200f239e4SArteen Abrishami // If replacing it would require both forms to stay live, then we do 4300f239e4SArteen Abrishami // not replace the hoisted transpose, causing the new chain to be dead. 4400f239e4SArteen Abrishami // Otherwise, we do and the old chain (untransposed form) becomes dead. Only one 4500f239e4SArteen Abrishami // chain will ever then be live, resulting in no duplication. 4600f239e4SArteen Abrishami 4700f239e4SArteen Abrishami // We then perform a simple one-pass DCE, so no canonicalization is necessary.
4800f239e4SArteen Abrishami 4900f239e4SArteen Abrishami // ----------- 5000f239e4SArteen Abrishami // Future Work: 5100f239e4SArteen Abrishami // ----------- 5200f239e4SArteen Abrishami 5300f239e4SArteen Abrishami // (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across 5400f239e4SArteen Abrishami // hoisted 5500f239e4SArteen Abrishami // transposes with different permutation tensors. 5600f239e4SArteen Abrishami 5700f239e4SArteen Abrishami // (2) Expand the class of foldable upstream ReshapeOp we permit beyond 5800f239e4SArteen Abrishami // N -> 1x1x...x1xNx1x...x1x1. 5900f239e4SArteen Abrishami 6000f239e4SArteen Abrishami // (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond 6100f239e4SArteen Abrishami // those that form the identity. 6200f239e4SArteen Abrishami 6300f239e4SArteen Abrishami // (4) Add support for more instructions besides TosaElementwiseOperator as 6400f239e4SArteen Abrishami // the intervening ones (for example, the reduce_* operators). 6500f239e4SArteen Abrishami 6600f239e4SArteen Abrishami // (5) Support hoisting transposes up to an input parameter.
6700f239e4SArteen Abrishami 6800f239e4SArteen Abrishami //===----------------------------------------------------------------------===// 6900f239e4SArteen Abrishami 7000f239e4SArteen Abrishami #include "mlir/Dialect/Func/IR/FuncOps.h" 7100f239e4SArteen Abrishami #include "mlir/Dialect/Tosa/IR/TosaOps.h" 7200f239e4SArteen Abrishami #include "mlir/Dialect/Tosa/Transforms/Passes.h" 7300f239e4SArteen Abrishami #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" 7400f239e4SArteen Abrishami #include "mlir/IR/Iterators.h" 7500f239e4SArteen Abrishami #include "mlir/IR/Matchers.h" 7600f239e4SArteen Abrishami #include "llvm/ADT/TypeSwitch.h" 7700f239e4SArteen Abrishami #include <memory> 7800f239e4SArteen Abrishami #include <set> 7900f239e4SArteen Abrishami #include <stack> 8000f239e4SArteen Abrishami 8100f239e4SArteen Abrishami namespace mlir { 8200f239e4SArteen Abrishami namespace tosa { 8300f239e4SArteen Abrishami #define GEN_PASS_DEF_TOSAREDUCETRANSPOSES 8400f239e4SArteen Abrishami #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 8500f239e4SArteen Abrishami } // namespace tosa 8600f239e4SArteen Abrishami } // namespace mlir 8700f239e4SArteen Abrishami 8800f239e4SArteen Abrishami using namespace mlir; 8900f239e4SArteen Abrishami using namespace mlir::tosa; 9000f239e4SArteen Abrishami 9100f239e4SArteen Abrishami //===----------------------------------------------------------------------===// 9200f239e4SArteen Abrishami // TOSA Reduce Transposes Pass. 
9300f239e4SArteen Abrishami //===----------------------------------------------------------------------===// 9400f239e4SArteen Abrishami 9500f239e4SArteen Abrishami namespace { 9600f239e4SArteen Abrishami 9700f239e4SArteen Abrishami struct TosaReduceTransposes final 9800f239e4SArteen Abrishami : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> { 9900f239e4SArteen Abrishami void runOnOperation() override; 10000f239e4SArteen Abrishami 10100f239e4SArteen Abrishami private: 10200f239e4SArteen Abrishami // This will collect all the data dependencies for the given Operation 10300f239e4SArteen Abrishami // up to and including ConstOp, ReshapeOp, and TransposeOp. 10400f239e4SArteen Abrishami bool collectFanIn(Operation *op, SetVector<Operation *> &collected); 10500f239e4SArteen Abrishami bool convertDependentOps(SetVector<Operation *> &dependentOps, 10600f239e4SArteen Abrishami DenseMap<Value, Value> &valuesMap, 10700f239e4SArteen Abrishami IRRewriter &rewriter, 10800f239e4SArteen Abrishami ArrayRef<int32_t> hoistedPerms); 10900f239e4SArteen Abrishami 11000f239e4SArteen Abrishami // Checks if the two permutations, when applied consecutively, result 11100f239e4SArteen Abrishami // in the identity. 11200f239e4SArteen Abrishami bool areInvolutionTransposes(ArrayRef<int32_t> perms1, 11300f239e4SArteen Abrishami ArrayRef<int32_t> perms2); 11400f239e4SArteen Abrishami 11500f239e4SArteen Abrishami // This is meant to apply to operations with the TosaElementwiseOperator 11600f239e4SArteen Abrishami // trait. 11700f239e4SArteen Abrishami std::optional<Value> 11800f239e4SArteen Abrishami buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap, 11900f239e4SArteen Abrishami IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms); 12000f239e4SArteen Abrishami 12100f239e4SArteen Abrishami // This updates valuesMap when we encounter another TransposeOp as a 12200f239e4SArteen Abrishami // dependency of the hoisted one. 
%0 = tosa.transpose %arg0 <- applies to 12300f239e4SArteen Abrishami // this %1 = tosa.transpose %0 <- when tracking back from this 12400f239e4SArteen Abrishami std::optional<Value> 12500f239e4SArteen Abrishami buildMappedToValue(TransposeOp transposeOp, 12600f239e4SArteen Abrishami const DenseMap<Value, Value> &valuesMap, 12700f239e4SArteen Abrishami IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms); 12800f239e4SArteen Abrishami 12900f239e4SArteen Abrishami // Checks if ReshapeOp can have hoisted TransposeOp folded into it. If so, 13000f239e4SArteen Abrishami // it creates new ReshapeOp with that fold. 13100f239e4SArteen Abrishami std::optional<Value> 13200f239e4SArteen Abrishami buildMappedToValue(ReshapeOp reshapeOp, 13300f239e4SArteen Abrishami const DenseMap<Value, Value> &valuesMap, 13400f239e4SArteen Abrishami IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms); 13500f239e4SArteen Abrishami 13600f239e4SArteen Abrishami // We may have something like: 13700f239e4SArteen Abrishami // %0 = tosa.const 13800f239e4SArteen Abrishami // %1 = tosa.transpose 13900f239e4SArteen Abrishami // %2 = tosa.add %0, %1 14000f239e4SArteen Abrishami // %3 = tosa.transpose %2 14100f239e4SArteen Abrishami // that --tosa-layerwise-const-fold wouldn't handle. This use shows up 14200f239e4SArteen Abrishami // in MobilenetV3. 14300f239e4SArteen Abrishami std::optional<Value> 14400f239e4SArteen Abrishami buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap, 14500f239e4SArteen Abrishami IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms); 14600f239e4SArteen Abrishami 14700f239e4SArteen Abrishami // Checks which TransposeOp we should "replace", turning their converted 14800f239e4SArteen Abrishami // chains of ops, through which they were propagated, "live", and the old code 14900f239e4SArteen Abrishami // "dead." Attempts to avoid doing so when doing so would result in the old 15000f239e4SArteen Abrishami // code staying "live," resulting in duplication. 
15100f239e4SArteen Abrishami std::set<TransposeOp> getGoodReplacements( 15200f239e4SArteen Abrishami ArrayRef<int32_t> perms, 15300f239e4SArteen Abrishami std::vector<std::pair<TransposeOp, SetVector<Operation *>>> 15400f239e4SArteen Abrishami &transposeInfo); 15500f239e4SArteen Abrishami 15600f239e4SArteen Abrishami // Helper function for dependenciesAreValid. 15700f239e4SArteen Abrishami bool userNotContainedInValidTransposeDependencies( 15800f239e4SArteen Abrishami Operation *user, std::set<TransposeOp> &validTransposes, 15900f239e4SArteen Abrishami std::vector<std::pair<TransposeOp, SetVector<Operation *>>> 16000f239e4SArteen Abrishami &transposeInfo); 16100f239e4SArteen Abrishami 16200f239e4SArteen Abrishami // Helper function for getGoodReplacements to check if some TransposeOp's 16300f239e4SArteen Abrishami // dependencies are OK. 16400f239e4SArteen Abrishami bool dependenciesAreValid( 16500f239e4SArteen Abrishami ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps, 16600f239e4SArteen Abrishami std::set<TransposeOp> &validTransposes, 16700f239e4SArteen Abrishami std::vector<std::pair<TransposeOp, SetVector<Operation *>>> 16800f239e4SArteen Abrishami &transposeInfo); 16900f239e4SArteen Abrishami 17000f239e4SArteen Abrishami // Applies perms to the DenseElementsAttr. 17100f239e4SArteen Abrishami // If it returns std::nullopt, it also triggers pass failure, since verifier 17200f239e4SArteen Abrishami // guarantees from TOSA are not in place (and otherwise, if used elsewhere, 17300f239e4SArteen Abrishami // it should fail). 17400f239e4SArteen Abrishami // This is a basic API and may benefit from refactor into the core MLIR APIs. 
17500f239e4SArteen Abrishami std::optional<DenseElementsAttr> 17600f239e4SArteen Abrishami transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms); 17700f239e4SArteen Abrishami }; 17800f239e4SArteen Abrishami 17900f239e4SArteen Abrishami std::optional<DenseElementsAttr> 18000f239e4SArteen Abrishami TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input, 18100f239e4SArteen Abrishami ArrayRef<int32_t> perms) { 18200f239e4SArteen Abrishami RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType()); 18300f239e4SArteen Abrishami RankedTensorType newType = 18400f239e4SArteen Abrishami RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms), 18500f239e4SArteen Abrishami oldType.getElementType()); 18600f239e4SArteen Abrishami size_t rank = oldType.getRank(); 18700f239e4SArteen Abrishami 18800f239e4SArteen Abrishami // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension 18900f239e4SArteen Abrishami // 0. If not in place, something is very wrong. 190*a58e774fSJack Frankland if (rank <= 0 || oldType.getNumElements() <= 0) { 19100f239e4SArteen Abrishami signalPassFailure(); 19200f239e4SArteen Abrishami return std::nullopt; 19300f239e4SArteen Abrishami } 19400f239e4SArteen Abrishami 19500f239e4SArteen Abrishami if (input.isSplat()) 19600f239e4SArteen Abrishami return input.reshape(newType); 19700f239e4SArteen Abrishami 19800f239e4SArteen Abrishami // The algorithm is approximately as follows: 19900f239e4SArteen Abrishami // input: perms, input flat array, input tensor type 20000f239e4SArteen Abrishami // (1/2) determine the strides of input/output if 20100f239e4SArteen Abrishami // they were strided in row-major order. (3) adjust the strides for the 20200f239e4SArteen Abrishami // input to be in the same order of indices as the output is written. 20300f239e4SArteen Abrishami // (4) process dimension by dimension. 
example: perms 2, 0, 1; input 20400f239e4SArteen Abrishami // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] = 20500f239e4SArteen Abrishami // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust 20600f239e4SArteen Abrishami // input strides to be as input[i + 12j + 4k] so we may process 20700f239e4SArteen Abrishami // layer-by-layer. 20800f239e4SArteen Abrishami 20900f239e4SArteen Abrishami // Step 1/2: Strides for input. We ignore output since row-major and can just 21000f239e4SArteen Abrishami // push_back. 21100f239e4SArteen Abrishami 21200f239e4SArteen Abrishami SmallVector<int64_t> originalInputStrides(rank); 21300f239e4SArteen Abrishami originalInputStrides[rank - 1] = 1; 21400f239e4SArteen Abrishami // index with int64_t to avoid overflow 21500f239e4SArteen Abrishami for (int64_t i = rank - 2; i >= 0; i--) 21600f239e4SArteen Abrishami originalInputStrides[i] = 21700f239e4SArteen Abrishami originalInputStrides[i + 1] * oldType.getDimSize(i + 1); 21800f239e4SArteen Abrishami 21900f239e4SArteen Abrishami // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as 22000f239e4SArteen Abrishami // output which is done in row-major order. 22100f239e4SArteen Abrishami 22200f239e4SArteen Abrishami SmallVector<int64_t> newInputStrides; 22300f239e4SArteen Abrishami newInputStrides.reserve(rank); 22400f239e4SArteen Abrishami for (int32_t v : perms) 22500f239e4SArteen Abrishami newInputStrides.push_back(originalInputStrides[v]); 22600f239e4SArteen Abrishami 22700f239e4SArteen Abrishami // Step 4: Write out the transposed "flat array" dimension by dimension. 
22800f239e4SArteen Abrishami 22900f239e4SArteen Abrishami auto inputArray = input.getValues<Attribute>(); 23000f239e4SArteen Abrishami SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides; 23100f239e4SArteen Abrishami for (size_t i = 0; i < rank; i++) 23200f239e4SArteen Abrishami boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]}); 23300f239e4SArteen Abrishami 23400f239e4SArteen Abrishami SmallVector<Attribute> resultArray; 23500f239e4SArteen Abrishami resultArray.reserve(inputArray.size()); 23600f239e4SArteen Abrishami 23700f239e4SArteen Abrishami std::function<void(int64_t, 23800f239e4SArteen Abrishami SmallVector<std::pair<int64_t, int64_t>>::const_iterator)> 23900f239e4SArteen Abrishami processTransposeDim = [&](auto accumulatedIndex, auto it) { 24000f239e4SArteen Abrishami if (it == boundsAndStrides.end()) { 24100f239e4SArteen Abrishami resultArray.push_back(inputArray[accumulatedIndex]); 24200f239e4SArteen Abrishami return; 24300f239e4SArteen Abrishami } 24400f239e4SArteen Abrishami 24500f239e4SArteen Abrishami for (int64_t i = 0; i < it->first; i++) { 24600f239e4SArteen Abrishami int64_t j = accumulatedIndex + i * it->second; 24700f239e4SArteen Abrishami processTransposeDim(j, it + 1); 24800f239e4SArteen Abrishami } 24900f239e4SArteen Abrishami }; 25000f239e4SArteen Abrishami 25100f239e4SArteen Abrishami processTransposeDim(0, boundsAndStrides.begin()); 25200f239e4SArteen Abrishami 25300f239e4SArteen Abrishami return DenseElementsAttr::get(newType, resultArray); 25400f239e4SArteen Abrishami } 25500f239e4SArteen Abrishami 25600f239e4SArteen Abrishami // The SetVector should only contain ConstOp, ReshapeOp, TransposeOp 25700f239e4SArteen Abrishami // as the sources of the data dependencies, and TosaElementWiseOperator 25800f239e4SArteen Abrishami // after that, if the function returns true. 
25900f239e4SArteen Abrishami bool TosaReduceTransposes::collectFanIn(Operation *op, 26000f239e4SArteen Abrishami SetVector<Operation *> &collected) { 26100f239e4SArteen Abrishami // Can occur if defined through the parameter to a func.func. 26200f239e4SArteen Abrishami if (!op) 26300f239e4SArteen Abrishami return false; 26400f239e4SArteen Abrishami 26500f239e4SArteen Abrishami if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect())) 26600f239e4SArteen Abrishami return false; 26700f239e4SArteen Abrishami 26800f239e4SArteen Abrishami // Prevent extra work if already seen. 26900f239e4SArteen Abrishami if (collected.contains(op)) 27000f239e4SArteen Abrishami return true; 27100f239e4SArteen Abrishami 27200f239e4SArteen Abrishami // Throw it out so later don't have to deal with this. 27300f239e4SArteen Abrishami if (op->getNumResults() != 1 || 27400f239e4SArteen Abrishami !llvm::isa<RankedTensorType>(op->getResult(0).getType())) 27500f239e4SArteen Abrishami return false; 27600f239e4SArteen Abrishami 27700f239e4SArteen Abrishami // We don't wish to traverse up a ReshapeOp, since generally we can't 27800f239e4SArteen Abrishami // propagate a TransposeOp through it. TransposeOp, ReshapeOp, ConstOp 27900f239e4SArteen Abrishami // will have no in-edges in the data dependency graph we construct for 28000f239e4SArteen Abrishami // the downstream TransposeOp. 28100f239e4SArteen Abrishami if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) && 28200f239e4SArteen Abrishami !llvm::isa<tosa::ConstOp>(op)) { 28300f239e4SArteen Abrishami 28400f239e4SArteen Abrishami if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()) 28500f239e4SArteen Abrishami return false; 28600f239e4SArteen Abrishami 28700f239e4SArteen Abrishami for (Value operand : op->getOperands()) 28800f239e4SArteen Abrishami // If this is a problem in future, think about alternatives to recursion. 
28900f239e4SArteen Abrishami if (!collectFanIn(operand.getDefiningOp(), collected)) 29000f239e4SArteen Abrishami return false; 29100f239e4SArteen Abrishami } 29200f239e4SArteen Abrishami 29300f239e4SArteen Abrishami // Insert in topological order. 29400f239e4SArteen Abrishami collected.insert(op); 29500f239e4SArteen Abrishami 29600f239e4SArteen Abrishami return true; 29700f239e4SArteen Abrishami } 29800f239e4SArteen Abrishami 29900f239e4SArteen Abrishami // Assuming that due to the verification of TransposeOp perms arrays are 30000f239e4SArteen Abrishami // permutations of 0 - perms.size() - 1. 30100f239e4SArteen Abrishami bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1, 30200f239e4SArteen Abrishami ArrayRef<int32_t> perms2) { 30300f239e4SArteen Abrishami if (perms1.size() != perms2.size()) 30400f239e4SArteen Abrishami return false; 30500f239e4SArteen Abrishami int32_t n = perms1.size(); 30600f239e4SArteen Abrishami for (int32_t i = 0; i < n; i++) 30700f239e4SArteen Abrishami if (perms2[perms1[i]] != i) 30800f239e4SArteen Abrishami return false; 30900f239e4SArteen Abrishami return true; 31000f239e4SArteen Abrishami } 31100f239e4SArteen Abrishami 31200f239e4SArteen Abrishami // Primary overload for those with TosaElementwiseOperator trait. 31300f239e4SArteen Abrishami // The other ones handle the case of the operations that occur at the 31400f239e4SArteen Abrishami // roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp). 
31500f239e4SArteen Abrishami std::optional<Value> TosaReduceTransposes::buildMappedToValue( 31600f239e4SArteen Abrishami Operation *op, const DenseMap<Value, Value> &valuesMap, 31700f239e4SArteen Abrishami IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 31800f239e4SArteen Abrishami if (op->getNumResults() != 1 || 31900f239e4SArteen Abrishami !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()) 32000f239e4SArteen Abrishami return std::nullopt; 32100f239e4SArteen Abrishami 32200f239e4SArteen Abrishami auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType()); 32300f239e4SArteen Abrishami SmallVector<Value, 3> operands; 32400f239e4SArteen Abrishami for (Value v : op->getOperands()) { 32500f239e4SArteen Abrishami if (valuesMap.contains(v)) { 32600f239e4SArteen Abrishami operands.push_back(valuesMap.at(v)); 32700f239e4SArteen Abrishami } else { 32800f239e4SArteen Abrishami return std::nullopt; 32900f239e4SArteen Abrishami } 33000f239e4SArteen Abrishami } 33100f239e4SArteen Abrishami 33200f239e4SArteen Abrishami // Conceptually, we propagate the hoisted TransposeOp through 33300f239e4SArteen Abrishami // these interveaning operations. 
For example, 33400f239e4SArteen Abrishami 33500f239e4SArteen Abrishami // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32> 33600f239e4SArteen Abrishami // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) -> 33700f239e4SArteen Abrishami // tensor<3x2xi32> 33800f239e4SArteen Abrishami 33900f239e4SArteen Abrishami // becomes: 34000f239e4SArteen Abrishami // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) -> 34100f239e4SArteen Abrishami // tensor<3x2xi32> 34200f239e4SArteen Abrishami // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>) 34300f239e4SArteen Abrishami 34400f239e4SArteen Abrishami // We construct this new tosa.clamp here, but it doesn't 34500f239e4SArteen Abrishami // turn "live" until the transpose being hoisted through this chain 34600f239e4SArteen Abrishami // is replaced with the proper value from the new chain. 34700f239e4SArteen Abrishami 34800f239e4SArteen Abrishami return rewriter 34900f239e4SArteen Abrishami .create(op->getLoc(), op->getName().getIdentifier(), operands, 35000f239e4SArteen Abrishami RankedTensorType::get( 35100f239e4SArteen Abrishami applyTOSAPermutation(resultType.getShape(), hoistedPerms), 35200f239e4SArteen Abrishami resultType.getElementType()), 35300f239e4SArteen Abrishami op->getAttrs()) 35400f239e4SArteen Abrishami ->getResult(0); 35500f239e4SArteen Abrishami } 35600f239e4SArteen Abrishami 35700f239e4SArteen Abrishami std::optional<Value> TosaReduceTransposes::buildMappedToValue( 35800f239e4SArteen Abrishami TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap, 35900f239e4SArteen Abrishami IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 36000f239e4SArteen Abrishami SmallVector<int32_t> perms; 36100f239e4SArteen Abrishami if (failed(transposeOp.getConstantPerms(perms)) || 36200f239e4SArteen Abrishami !areInvolutionTransposes(hoistedPerms, perms)) 36300f239e4SArteen Abrishami return std::nullopt; 36400f239e4SArteen Abrishami return transposeOp.getInput1(); 
36500f239e4SArteen Abrishami } 36600f239e4SArteen Abrishami 36700f239e4SArteen Abrishami std::optional<Value> TosaReduceTransposes::buildMappedToValue( 36800f239e4SArteen Abrishami ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap, 36900f239e4SArteen Abrishami IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 37000f239e4SArteen Abrishami auto reshapeOutput = reshapeOp.getOutput(); 37100f239e4SArteen Abrishami auto reshapeInputType = 37200f239e4SArteen Abrishami llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType()); 37300f239e4SArteen Abrishami auto reshapeInputShape = reshapeInputType.getShape(); 37400f239e4SArteen Abrishami // want reshape N -> 1x1x...x1xNx1x...x1x1 37500f239e4SArteen Abrishami if (!reshapeInputType || reshapeInputShape.size() != 1) 37600f239e4SArteen Abrishami return std::nullopt; 37700f239e4SArteen Abrishami auto reshapeOutputType = 37800f239e4SArteen Abrishami llvm::cast<RankedTensorType>(reshapeOutput.getType()); 37900f239e4SArteen Abrishami 38000f239e4SArteen Abrishami // Instead of inserting a TransposeOp here, we check if we can fold it into 38100f239e4SArteen Abrishami // the ReshapeOp. There is more complex cases where this is possible, and 38200f239e4SArteen Abrishami // this check can be extended. 38300f239e4SArteen Abrishami 38400f239e4SArteen Abrishami // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1 38500f239e4SArteen Abrishami auto shape = reshapeOutputType.getShape(); 38600f239e4SArteen Abrishami size_t ones = llvm::count(shape, 1); 38700f239e4SArteen Abrishami // N == 1 and N != 1 38800f239e4SArteen Abrishami if (ones != shape.size() - 1 && 38900f239e4SArteen Abrishami !(ones == shape.size() && reshapeInputShape[0] == 1)) 39000f239e4SArteen Abrishami return std::nullopt; 39100f239e4SArteen Abrishami 39200f239e4SArteen Abrishami // Do not insert a TransposeOp, instead we fold the reshape and its attribute. 
39300f239e4SArteen Abrishami auto foldedReshape = rewriter.create<ReshapeOp>( 39400f239e4SArteen Abrishami reshapeOp.getLoc(), 39500f239e4SArteen Abrishami RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms), 39600f239e4SArteen Abrishami reshapeOutputType.getElementType()), 39700f239e4SArteen Abrishami reshapeOp.getInput1(), 39800f239e4SArteen Abrishami rewriter.getDenseI64ArrayAttr( 39900f239e4SArteen Abrishami applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms))); 40000f239e4SArteen Abrishami return foldedReshape->getResult(0); 40100f239e4SArteen Abrishami } 40200f239e4SArteen Abrishami 40300f239e4SArteen Abrishami std::optional<Value> TosaReduceTransposes::buildMappedToValue( 40400f239e4SArteen Abrishami ConstOp constOp, const DenseMap<Value, Value> &valuesMap, 40500f239e4SArteen Abrishami IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 40600f239e4SArteen Abrishami auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue()); 40700f239e4SArteen Abrishami if (!denseAttr) 40800f239e4SArteen Abrishami return std::nullopt; 40900f239e4SArteen Abrishami auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms); 41000f239e4SArteen Abrishami if (!maybeNewDenseAttr.has_value()) 41100f239e4SArteen Abrishami return std::nullopt; 41200f239e4SArteen Abrishami auto newDenseAttr = maybeNewDenseAttr.value(); 41300f239e4SArteen Abrishami auto newConstOp = rewriter.create<ConstOp>( 41400f239e4SArteen Abrishami constOp.getLoc(), newDenseAttr.getType(), newDenseAttr); 41500f239e4SArteen Abrishami return newConstOp->getResult(0); 41600f239e4SArteen Abrishami } 41700f239e4SArteen Abrishami 41800f239e4SArteen Abrishami bool TosaReduceTransposes::convertDependentOps( 41900f239e4SArteen Abrishami SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap, 42000f239e4SArteen Abrishami IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 42100f239e4SArteen Abrishami 42200f239e4SArteen Abrishami for (Operation *op : 
dependentOps) { 42300f239e4SArteen Abrishami if (!op || op->getNumResults() != 1) 42400f239e4SArteen Abrishami return false; 42500f239e4SArteen Abrishami 42600f239e4SArteen Abrishami Value priorValue = op->getResult(0); 42700f239e4SArteen Abrishami 42800f239e4SArteen Abrishami // It's possible on a prior transposeOp we had the same dependency and 42900f239e4SArteen Abrishami // already resolved it. 43000f239e4SArteen Abrishami if (valuesMap.contains(priorValue)) 43100f239e4SArteen Abrishami continue; 43200f239e4SArteen Abrishami 43300f239e4SArteen Abrishami // Keep converted ops close to the original. 43400f239e4SArteen Abrishami rewriter.setInsertionPointAfter(op); 43500f239e4SArteen Abrishami 43600f239e4SArteen Abrishami std::optional<Value> maybeValue = 43700f239e4SArteen Abrishami llvm::TypeSwitch<Operation *, std::optional<Value>>(op) 43800f239e4SArteen Abrishami .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto transposeOp) { 43900f239e4SArteen Abrishami return buildMappedToValue(transposeOp, valuesMap, rewriter, 44000f239e4SArteen Abrishami hoistedPerms); 44100f239e4SArteen Abrishami }) 44200f239e4SArteen Abrishami .Default([&](Operation *op) { 44300f239e4SArteen Abrishami return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms); 44400f239e4SArteen Abrishami }); 44500f239e4SArteen Abrishami 44600f239e4SArteen Abrishami if (!maybeValue.has_value()) 44700f239e4SArteen Abrishami return false; 44800f239e4SArteen Abrishami 44900f239e4SArteen Abrishami valuesMap[priorValue] = maybeValue.value(); 45000f239e4SArteen Abrishami } 45100f239e4SArteen Abrishami 45200f239e4SArteen Abrishami return true; 45300f239e4SArteen Abrishami } 45400f239e4SArteen Abrishami 45500f239e4SArteen Abrishami bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies( 45600f239e4SArteen Abrishami Operation *user, std::set<TransposeOp> &validTransposes, 45700f239e4SArteen Abrishami std::vector<std::pair<TransposeOp, SetVector<Operation *>>> 45800f239e4SArteen Abrishami 
&transposeInfo) { 45900f239e4SArteen Abrishami return llvm::none_of( 46000f239e4SArteen Abrishami transposeInfo, 46100f239e4SArteen Abrishami [&validTransposes, 46200f239e4SArteen Abrishami user](const std::pair<TransposeOp, SetVector<Operation *>> &info) { 46300f239e4SArteen Abrishami const auto &[transposeOp, dependentOps] = info; 46400f239e4SArteen Abrishami return validTransposes.count(transposeOp) && 46500f239e4SArteen Abrishami dependentOps.contains(user); 46600f239e4SArteen Abrishami }); 46700f239e4SArteen Abrishami } 46800f239e4SArteen Abrishami 46900f239e4SArteen Abrishami // Dependencies are valid for an operation if none of them occur outside 47000f239e4SArteen Abrishami // of the proper fan-in cones of the hoisted TransposeOp with the same perms 47100f239e4SArteen Abrishami // that we can replace. Described in more detail within. 47200f239e4SArteen Abrishami bool TosaReduceTransposes::dependenciesAreValid( 47300f239e4SArteen Abrishami ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps, 47400f239e4SArteen Abrishami std::set<TransposeOp> &validTransposes, 47500f239e4SArteen Abrishami std::vector<std::pair<TransposeOp, SetVector<Operation *>>> 47600f239e4SArteen Abrishami &transposeInfo) { 47700f239e4SArteen Abrishami for (Operation *op : dependentOps) { 47800f239e4SArteen Abrishami 47900f239e4SArteen Abrishami // It's OK wherever ConstOp has uses -- in the worst case, we duplicate. 48000f239e4SArteen Abrishami // This can be changed later if we find the memory impact is too high. 48100f239e4SArteen Abrishami if (llvm::isa<ConstOp>(op)) 48200f239e4SArteen Abrishami continue; 48300f239e4SArteen Abrishami 48400f239e4SArteen Abrishami for (OpOperand &use : op->getUses()) { 48500f239e4SArteen Abrishami // Want the uses to be (1) contained in the dependentOps of other 48600f239e4SArteen Abrishami // validTransposes, or (2) to be directly used in a TransposeOp with the 48700f239e4SArteen Abrishami // same perms. 
For (2) it means the fan-in is a subset of our 48800f239e4SArteen Abrishami // dependentOps, so it is also a validTranspose that will eventually be 48900f239e4SArteen Abrishami // replaced. 49000f239e4SArteen Abrishami Operation *user = use.getOwner(); 49100f239e4SArteen Abrishami if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) { 49200f239e4SArteen Abrishami SmallVector<int32_t> otherPerms; 49300f239e4SArteen Abrishami 49400f239e4SArteen Abrishami // Can later think about cases where transpose -> transpose 49500f239e4SArteen Abrishami // or reshape -> transpose, where the transposes are not necessarily 49600f239e4SArteen Abrishami // the same perms as the hoisted, if implementing a more general 49700f239e4SArteen Abrishami // transform. These could be permitted. 49800f239e4SArteen Abrishami if (failed(otherTranspose.getConstantPerms(otherPerms)) || 49900f239e4SArteen Abrishami !llvm::equal(perms, otherPerms)) 50000f239e4SArteen Abrishami return false; 50100f239e4SArteen Abrishami } else if (userNotContainedInValidTransposeDependencies( 50200f239e4SArteen Abrishami user, validTransposes, transposeInfo)) { 50300f239e4SArteen Abrishami return false; 50400f239e4SArteen Abrishami } 50500f239e4SArteen Abrishami } 50600f239e4SArteen Abrishami } 50700f239e4SArteen Abrishami 50800f239e4SArteen Abrishami return true; 50900f239e4SArteen Abrishami } 51000f239e4SArteen Abrishami 51100f239e4SArteen Abrishami // Getting the set of TransposeOp that we can replace without causing 51200f239e4SArteen Abrishami // the old fan-in cones of any TransposeOp to remain "live", i.e, -- not being 51300f239e4SArteen Abrishami // dead code. 
This is done by iterating the set until convergence, since 51400f239e4SArteen Abrishami // if you are used outside your own fan-in cone, it's possible to be used 51500f239e4SArteen Abrishami // in another fan-in cone of a TransposeOp that is being replaced -- unless 51600f239e4SArteen Abrishami // we find that that one has a usage outside of it too. 51700f239e4SArteen Abrishami std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements( 51800f239e4SArteen Abrishami ArrayRef<int32_t> perms, 51900f239e4SArteen Abrishami std::vector<std::pair<TransposeOp, SetVector<Operation *>>> 52000f239e4SArteen Abrishami &transposeInfo) { 52100f239e4SArteen Abrishami // Initially, we assume they are all good to replace, 52200f239e4SArteen Abrishami // and we whittle them down based on our criteria. 52300f239e4SArteen Abrishami std::set<TransposeOp> ableToReplace; 52400f239e4SArteen Abrishami for (const auto &[transposeOp, _] : transposeInfo) 52500f239e4SArteen Abrishami ableToReplace.insert(transposeOp); 52600f239e4SArteen Abrishami 52700f239e4SArteen Abrishami bool gotRid; 52800f239e4SArteen Abrishami do { 52900f239e4SArteen Abrishami gotRid = false; 53000f239e4SArteen Abrishami for (const auto &[transposeOp, dependentOps] : transposeInfo) { 53100f239e4SArteen Abrishami // We don't care about it. Already invalidated. 53200f239e4SArteen Abrishami if (!ableToReplace.count(transposeOp)) 53300f239e4SArteen Abrishami continue; 53400f239e4SArteen Abrishami 53500f239e4SArteen Abrishami // Check for validity. 
53600f239e4SArteen Abrishami if (!dependenciesAreValid(perms, dependentOps, ableToReplace, 53700f239e4SArteen Abrishami transposeInfo)) { 53800f239e4SArteen Abrishami ableToReplace.erase(transposeOp); 53900f239e4SArteen Abrishami gotRid = true; 54000f239e4SArteen Abrishami break; 54100f239e4SArteen Abrishami } 54200f239e4SArteen Abrishami } 54300f239e4SArteen Abrishami 54400f239e4SArteen Abrishami } while (gotRid); 54500f239e4SArteen Abrishami 54600f239e4SArteen Abrishami return ableToReplace; 54700f239e4SArteen Abrishami } 54800f239e4SArteen Abrishami 54900f239e4SArteen Abrishami void TosaReduceTransposes::runOnOperation() { 55000f239e4SArteen Abrishami // We want to operate only within a single block. 55100f239e4SArteen Abrishami if (!getOperation().getRegion().hasOneBlock()) 55200f239e4SArteen Abrishami return; 55300f239e4SArteen Abrishami 55400f239e4SArteen Abrishami IRRewriter rewriter(&getContext()); 55500f239e4SArteen Abrishami // For each perms, maintain a mapping for converted ops, avoid duplication. 55600f239e4SArteen Abrishami DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues; 55700f239e4SArteen Abrishami // For each perms, we keep track of which TransposeOp are eligible 55800f239e4SArteen Abrishami // for replacement alongside their dependentOps. 55900f239e4SArteen Abrishami DenseMap<ArrayRef<int32_t>, 56000f239e4SArteen Abrishami std::vector<std::pair<TransposeOp, SetVector<Operation *>>>> 56100f239e4SArteen Abrishami permsToTransposeInfo; 56200f239e4SArteen Abrishami 56300f239e4SArteen Abrishami // Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef. 56400f239e4SArteen Abrishami // Use SmallVector for perms (common-case is <= 4) but std::vector otherwise 56500f239e4SArteen Abrishami // since no guarantee of smallness. 
56600f239e4SArteen Abrishami std::vector<SmallVector<int32_t>> collectedPerms; 56700f239e4SArteen Abrishami 56800f239e4SArteen Abrishami // This keeps track of the order across all eligible-for-replacement 56900f239e4SArteen Abrishami // TransposeOp and their perms, a necessity for the final replacements. 57000f239e4SArteen Abrishami std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder; 57100f239e4SArteen Abrishami 57200f239e4SArteen Abrishami // We want to reserve the space up front, since SmallVector stores some data 57300f239e4SArteen Abrishami // internally and the ArrayRef can reference that, which we don't want to get 57400f239e4SArteen Abrishami // invalidated. 57500f239e4SArteen Abrishami size_t expectedMaxPerms = 0; 57600f239e4SArteen Abrishami getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; }); 57700f239e4SArteen Abrishami collectedPerms.reserve(expectedMaxPerms); 57800f239e4SArteen Abrishami 57900f239e4SArteen Abrishami getOperation().walk([&](TransposeOp transposeOp) { 58000f239e4SArteen Abrishami SetVector<Operation *> dependentOps; 58100f239e4SArteen Abrishami collectedPerms.emplace_back(); 58200f239e4SArteen Abrishami SmallVector<int32_t> &perms = collectedPerms.back(); 58300f239e4SArteen Abrishami 58400f239e4SArteen Abrishami // Dynamic shapes are OK, but the incompatible ones will be rejected later. 58500f239e4SArteen Abrishami auto input = transposeOp.getInput1(); 58600f239e4SArteen Abrishami auto output = transposeOp.getOutput(); 58700f239e4SArteen Abrishami 58800f239e4SArteen Abrishami // However, we don't support unranked tensors. 58900f239e4SArteen Abrishami if (!llvm::isa<RankedTensorType>(input.getType()) || 59000f239e4SArteen Abrishami !llvm::isa<RankedTensorType>(output.getType())) 59100f239e4SArteen Abrishami return; 59200f239e4SArteen Abrishami 59300f239e4SArteen Abrishami // No transformation when transpose permutation non-constant. 
59400f239e4SArteen Abrishami if (failed(transposeOp.getConstantPerms(perms))) 59500f239e4SArteen Abrishami return; 59600f239e4SArteen Abrishami 59700f239e4SArteen Abrishami // We let --canonicalize deal with identity transpose. 59800f239e4SArteen Abrishami if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms)) 59900f239e4SArteen Abrishami return; 60000f239e4SArteen Abrishami 60100f239e4SArteen Abrishami // Can fail if some set of basic invariants is not met that we want to 60200f239e4SArteen Abrishami // perform our conversions. 60300f239e4SArteen Abrishami if (!collectFanIn(input.getDefiningOp(), dependentOps)) 60400f239e4SArteen Abrishami return; 60500f239e4SArteen Abrishami 60600f239e4SArteen Abrishami // Want to associate valuesMap for already converted of the same perms, 60700f239e4SArteen Abrishami // since it's possible multiple hoisted transposes w/ different perms 60800f239e4SArteen Abrishami // converge on an op, which would result in different transformations. 60900f239e4SArteen Abrishami DenseMap<Value, Value> &valuesMap = permsToValues[perms]; 61000f239e4SArteen Abrishami 61100f239e4SArteen Abrishami // Attempt to perform the conversions and placements into IR 61200f239e4SArteen Abrishami // without turning inserted code "live". Also fills out valuesMap. 61300f239e4SArteen Abrishami // Fails if there is an intermediary we do not support. 61400f239e4SArteen Abrishami if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms)) 61500f239e4SArteen Abrishami // Some additional operations may have been inserted, but will be 61600f239e4SArteen Abrishami // removed by dead code elimination. 61700f239e4SArteen Abrishami return; 61800f239e4SArteen Abrishami 61900f239e4SArteen Abrishami // This should not happen. If it does -- it's unexpected, 62000f239e4SArteen Abrishami // so we fail the pass. 
62100f239e4SArteen Abrishami if (!valuesMap.contains(input)) 62200f239e4SArteen Abrishami return signalPassFailure(); 62300f239e4SArteen Abrishami 62400f239e4SArteen Abrishami // It's possible the types are not compatible (because of dynamic shapes), 62500f239e4SArteen Abrishami // and in these cases, want to resolve dynamic shapes before running the 62600f239e4SArteen Abrishami // pass. 62700f239e4SArteen Abrishami if (output.getType() != valuesMap.at(input).getType()) 62800f239e4SArteen Abrishami return; 62900f239e4SArteen Abrishami 63000f239e4SArteen Abrishami auto &transposeInfo = permsToTransposeInfo[perms]; 63100f239e4SArteen Abrishami 63200f239e4SArteen Abrishami // In general, we might also want to introduce "newDependentOps" 63300f239e4SArteen Abrishami // if there are new usages that don't fall inside the original fan-ins 63400f239e4SArteen Abrishami // (like the TransposeOp we insert for ReshapeOp), 63500f239e4SArteen Abrishami // but in this case, that is specialized enough and overlaps 63600f239e4SArteen Abrishami // with another direct-use TransposeOp case we need to cover anyway. 63700f239e4SArteen Abrishami transposeInfo.push_back({transposeOp, dependentOps}); 63800f239e4SArteen Abrishami 63900f239e4SArteen Abrishami // This is for the final replacement across all transposes. 64000f239e4SArteen Abrishami totalTransposeOrder.push({transposeOp, perms}); 64100f239e4SArteen Abrishami }); 64200f239e4SArteen Abrishami 64300f239e4SArteen Abrishami // We want to do a full fan-in analysis on a perms-level, 64400f239e4SArteen Abrishami // since if we do it on a multi-perms level, and they share (due to a shared 64500f239e4SArteen Abrishami // dependency on a Reshape) then we would also get duplicate ops. 64600f239e4SArteen Abrishami // Const is special cased. 
64700f239e4SArteen Abrishami std::set<TransposeOp> ableToReplace; 64800f239e4SArteen Abrishami for (auto &[perms, transposeInfo] : permsToTransposeInfo) { 64900f239e4SArteen Abrishami // Gives us back replacements that would never result in any duplicate 65000f239e4SArteen Abrishami // operations being inserted by us in the IR (i.e, our goal is only to 65100f239e4SArteen Abrishami // remove transposes, and not create a "new chain" to do so, but replace 65200f239e4SArteen Abrishami // the existing chains). 65300f239e4SArteen Abrishami // Ideally, --canonicalize is run before this pass, since it helps this 65400f239e4SArteen Abrishami // analysis by removing dead code to allow more potentially acceptable 65500f239e4SArteen Abrishami // transformations. 65600f239e4SArteen Abrishami auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo); 65700f239e4SArteen Abrishami ableToReplace.insert(goodReplacementsForPerms.begin(), 65800f239e4SArteen Abrishami goodReplacementsForPerms.end()); 65900f239e4SArteen Abrishami } 66000f239e4SArteen Abrishami 66100f239e4SArteen Abrishami // We want to do replacement across all transposes 66200f239e4SArteen Abrishami // in reverse order, due to invalidation of valuesMap mappings 66300f239e4SArteen Abrishami // if we did it otherwise. 66400f239e4SArteen Abrishami while (!totalTransposeOrder.empty()) { 66500f239e4SArteen Abrishami auto [transposeOp, perms] = totalTransposeOrder.top(); 66600f239e4SArteen Abrishami totalTransposeOrder.pop(); 66700f239e4SArteen Abrishami 66800f239e4SArteen Abrishami if (ableToReplace.count(transposeOp) == 0) 66900f239e4SArteen Abrishami continue; 67000f239e4SArteen Abrishami 67100f239e4SArteen Abrishami auto &valuesMap = permsToValues[perms]; 67200f239e4SArteen Abrishami auto input = transposeOp.getInput1(); 67300f239e4SArteen Abrishami 67400f239e4SArteen Abrishami // The purpose of this reverse iteration 67500f239e4SArteen Abrishami // is to avoid valuesMap invalidation. 
If it happens, 67600f239e4SArteen Abrishami // something is wrong. 67700f239e4SArteen Abrishami if (!valuesMap.contains(input)) 67800f239e4SArteen Abrishami return signalPassFailure(); 67900f239e4SArteen Abrishami 68000f239e4SArteen Abrishami rewriter.replaceOp(transposeOp, valuesMap.at(input)); 68100f239e4SArteen Abrishami } 68200f239e4SArteen Abrishami 68300f239e4SArteen Abrishami // We can remove all dead code by going in reverse. 68400f239e4SArteen Abrishami // This is because we would remove usages before we 68500f239e4SArteen Abrishami // see the users. 68600f239e4SArteen Abrishami getOperation().walk<WalkOrder::PostOrder, ReverseIterator>( 68700f239e4SArteen Abrishami [&](Operation *op) { 68800f239e4SArteen Abrishami if (isOpTriviallyDead(op)) 68900f239e4SArteen Abrishami rewriter.eraseOp(op); 69000f239e4SArteen Abrishami }); 69100f239e4SArteen Abrishami } 69200f239e4SArteen Abrishami 69300f239e4SArteen Abrishami } // namespace 694