xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp (revision a58e774fba42e13aa00667d644e96b783fc914b4)
100f239e4SArteen Abrishami //===- TosaReduceTransposes.cpp -------------------------------------------===//
200f239e4SArteen Abrishami //
300f239e4SArteen Abrishami // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
400f239e4SArteen Abrishami // See https://llvm.org/LICENSE.txt for license information.
500f239e4SArteen Abrishami // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
600f239e4SArteen Abrishami //
700f239e4SArteen Abrishami //===----------------------------------------------------------------------===//
800f239e4SArteen Abrishami 
900f239e4SArteen Abrishami // ----------
1000f239e4SArteen Abrishami // Motivation:
1100f239e4SArteen Abrishami // ----------
1200f239e4SArteen Abrishami 
1300f239e4SArteen Abrishami // Some legalization pathways introduce redundant tosa.TRANSPOSE
1400f239e4SArteen Abrishami // operations that result in avoidable data movement. For example,
1500f239e4SArteen Abrishami // PyTorch -> TOSA contains a lot of unnecessary transposes due
1600f239e4SArteen Abrishami // to conversions between NCHW and NHWC.
1700f239e4SArteen Abrishami 
1800f239e4SArteen Abrishami // We wish to remove all the ones that we can, since in general
1900f239e4SArteen Abrishami // it is possible to remove the overwhelming majority.
2000f239e4SArteen Abrishami 
2100f239e4SArteen Abrishami // -------------------
2200f239e4SArteen Abrishami // High-Level Overview:
2300f239e4SArteen Abrishami // -------------------
2400f239e4SArteen Abrishami 
2500f239e4SArteen Abrishami // The pass works through the transpose operators in the program. It begins at
2600f239e4SArteen Abrishami // some transpose operator with an associated permutations tensor. It traverses
2700f239e4SArteen Abrishami // upwards through the dependencies of this transpose and verifies that we
2800f239e4SArteen Abrishami // encounter only operators with the TosaElementwiseOperator trait and terminate
2900f239e4SArteen Abrishami // in either constants, reshapes, or transposes.
3000f239e4SArteen Abrishami 
3100f239e4SArteen Abrishami // We then evaluate whether there are any additional restrictions (the
3200f239e4SArteen Abrishami // transposes it terminates in must invert the one we began at, and the reshapes
3300f239e4SArteen Abrishami // must be ones in which we can fold the transpose into), and then we hoist the
3400f239e4SArteen Abrishami // transpose through the intervening operators, folding it at the constants,
3500f239e4SArteen Abrishami // reshapes, and transposes.
3600f239e4SArteen Abrishami 
3700f239e4SArteen Abrishami // Finally, we ensure that we do not need both the transposed form (the form
3800f239e4SArteen Abrishami // that had the transpose hoisted through it) and the untransposed form (which
3900f239e4SArteen Abrishami // it was prior), by analyzing the usages of those dependent operators of a
4000f239e4SArteen Abrishami // given transpose we are attempting to hoist and replace.
4100f239e4SArteen Abrishami 
4200f239e4SArteen Abrishami // If they are such that it would require both forms to be necessary, then we do
4300f239e4SArteen Abrishami // not replace the hoisted transpose, causing the new chain to be dead.
4400f239e4SArteen Abrishami // Otherwise, we do and the old chain (untransposed form) becomes dead. Only one
4500f239e4SArteen Abrishami // chain will ever then be live, resulting in no duplication.
4600f239e4SArteen Abrishami 
4700f239e4SArteen Abrishami // We then perform a simple one-pass DCE, so no canonicalization is necessary.
4800f239e4SArteen Abrishami 
4900f239e4SArteen Abrishami // -----------
5000f239e4SArteen Abrishami // Future Work:
5100f239e4SArteen Abrishami // -----------
5200f239e4SArteen Abrishami 
// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
//     hoisted transposes with different permutation tensors.
5600f239e4SArteen Abrishami 
5700f239e4SArteen Abrishami // (2) Expand the class of foldable upstream ReshapeOp we permit beyond
5800f239e4SArteen Abrishami //     N -> 1x1x...x1xNx1x...x1x1.
5900f239e4SArteen Abrishami 
// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
//     those that form the identity.
6200f239e4SArteen Abrishami 
6300f239e4SArteen Abrishami // (4) Add support for more instructions besides TosaElementwiseOperator as
6400f239e4SArteen Abrishami //     the intervening ones (for example, the reduce_* operators).
6500f239e4SArteen Abrishami 
6600f239e4SArteen Abrishami // (5) Support hoisting transposes up to an input parameter.
6700f239e4SArteen Abrishami 
6800f239e4SArteen Abrishami //===----------------------------------------------------------------------===//
6900f239e4SArteen Abrishami 
7000f239e4SArteen Abrishami #include "mlir/Dialect/Func/IR/FuncOps.h"
7100f239e4SArteen Abrishami #include "mlir/Dialect/Tosa/IR/TosaOps.h"
7200f239e4SArteen Abrishami #include "mlir/Dialect/Tosa/Transforms/Passes.h"
7300f239e4SArteen Abrishami #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
7400f239e4SArteen Abrishami #include "mlir/IR/Iterators.h"
7500f239e4SArteen Abrishami #include "mlir/IR/Matchers.h"
7600f239e4SArteen Abrishami #include "llvm/ADT/TypeSwitch.h"
7700f239e4SArteen Abrishami #include <memory>
7800f239e4SArteen Abrishami #include <set>
7900f239e4SArteen Abrishami #include <stack>
8000f239e4SArteen Abrishami 
8100f239e4SArteen Abrishami namespace mlir {
8200f239e4SArteen Abrishami namespace tosa {
8300f239e4SArteen Abrishami #define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
8400f239e4SArteen Abrishami #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
8500f239e4SArteen Abrishami } // namespace tosa
8600f239e4SArteen Abrishami } // namespace mlir
8700f239e4SArteen Abrishami 
8800f239e4SArteen Abrishami using namespace mlir;
8900f239e4SArteen Abrishami using namespace mlir::tosa;
9000f239e4SArteen Abrishami 
9100f239e4SArteen Abrishami //===----------------------------------------------------------------------===//
9200f239e4SArteen Abrishami // TOSA Reduce Transposes Pass.
9300f239e4SArteen Abrishami //===----------------------------------------------------------------------===//
9400f239e4SArteen Abrishami 
9500f239e4SArteen Abrishami namespace {
9600f239e4SArteen Abrishami 
9700f239e4SArteen Abrishami struct TosaReduceTransposes final
9800f239e4SArteen Abrishami     : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
9900f239e4SArteen Abrishami   void runOnOperation() override;
10000f239e4SArteen Abrishami 
10100f239e4SArteen Abrishami private:
10200f239e4SArteen Abrishami   // This will collect all the data dependencies for the given Operation
10300f239e4SArteen Abrishami   // up to and including ConstOp, ReshapeOp, and TransposeOp.
10400f239e4SArteen Abrishami   bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
10500f239e4SArteen Abrishami   bool convertDependentOps(SetVector<Operation *> &dependentOps,
10600f239e4SArteen Abrishami                            DenseMap<Value, Value> &valuesMap,
10700f239e4SArteen Abrishami                            IRRewriter &rewriter,
10800f239e4SArteen Abrishami                            ArrayRef<int32_t> hoistedPerms);
10900f239e4SArteen Abrishami 
11000f239e4SArteen Abrishami   // Checks if the two permutations, when applied consecutively, result
11100f239e4SArteen Abrishami   // in the identity.
11200f239e4SArteen Abrishami   bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
11300f239e4SArteen Abrishami                                ArrayRef<int32_t> perms2);
11400f239e4SArteen Abrishami 
11500f239e4SArteen Abrishami   // This is meant to apply to operations with the TosaElementwiseOperator
11600f239e4SArteen Abrishami   // trait.
11700f239e4SArteen Abrishami   std::optional<Value>
11800f239e4SArteen Abrishami   buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
11900f239e4SArteen Abrishami                      IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
12000f239e4SArteen Abrishami 
12100f239e4SArteen Abrishami   // This updates valuesMap when we encounter another TransposeOp as a
12200f239e4SArteen Abrishami   // dependency of the hoisted one. %0 = tosa.transpose %arg0 <- applies to
12300f239e4SArteen Abrishami   // this %1 = tosa.transpose %0 <- when tracking back from this
12400f239e4SArteen Abrishami   std::optional<Value>
12500f239e4SArteen Abrishami   buildMappedToValue(TransposeOp transposeOp,
12600f239e4SArteen Abrishami                      const DenseMap<Value, Value> &valuesMap,
12700f239e4SArteen Abrishami                      IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
12800f239e4SArteen Abrishami 
12900f239e4SArteen Abrishami   // Checks if ReshapeOp can have hoisted TransposeOp folded into it. If so,
13000f239e4SArteen Abrishami   // it creates new ReshapeOp with that fold.
13100f239e4SArteen Abrishami   std::optional<Value>
13200f239e4SArteen Abrishami   buildMappedToValue(ReshapeOp reshapeOp,
13300f239e4SArteen Abrishami                      const DenseMap<Value, Value> &valuesMap,
13400f239e4SArteen Abrishami                      IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
13500f239e4SArteen Abrishami 
13600f239e4SArteen Abrishami   // We may have something like:
13700f239e4SArteen Abrishami   // %0 = tosa.const
13800f239e4SArteen Abrishami   // %1 = tosa.transpose
13900f239e4SArteen Abrishami   // %2 = tosa.add %0, %1
14000f239e4SArteen Abrishami   // %3 = tosa.transpose %2
14100f239e4SArteen Abrishami   // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
14200f239e4SArteen Abrishami   // in MobilenetV3.
14300f239e4SArteen Abrishami   std::optional<Value>
14400f239e4SArteen Abrishami   buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
14500f239e4SArteen Abrishami                      IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
14600f239e4SArteen Abrishami 
14700f239e4SArteen Abrishami   // Checks which TransposeOp we should "replace", turning their converted
14800f239e4SArteen Abrishami   // chains of ops, through which they were propagated, "live", and the old code
14900f239e4SArteen Abrishami   // "dead." Attempts to avoid doing so when doing so would result in the old
15000f239e4SArteen Abrishami   // code staying "live," resulting in duplication.
15100f239e4SArteen Abrishami   std::set<TransposeOp> getGoodReplacements(
15200f239e4SArteen Abrishami       ArrayRef<int32_t> perms,
15300f239e4SArteen Abrishami       std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
15400f239e4SArteen Abrishami           &transposeInfo);
15500f239e4SArteen Abrishami 
15600f239e4SArteen Abrishami   // Helper function for dependenciesAreValid.
15700f239e4SArteen Abrishami   bool userNotContainedInValidTransposeDependencies(
15800f239e4SArteen Abrishami       Operation *user, std::set<TransposeOp> &validTransposes,
15900f239e4SArteen Abrishami       std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
16000f239e4SArteen Abrishami           &transposeInfo);
16100f239e4SArteen Abrishami 
16200f239e4SArteen Abrishami   // Helper function for getGoodReplacements to check if some TransposeOp's
16300f239e4SArteen Abrishami   // dependencies are OK.
16400f239e4SArteen Abrishami   bool dependenciesAreValid(
16500f239e4SArteen Abrishami       ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
16600f239e4SArteen Abrishami       std::set<TransposeOp> &validTransposes,
16700f239e4SArteen Abrishami       std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
16800f239e4SArteen Abrishami           &transposeInfo);
16900f239e4SArteen Abrishami 
17000f239e4SArteen Abrishami   // Applies perms to the DenseElementsAttr.
17100f239e4SArteen Abrishami   // If it returns std::nullopt, it also triggers pass failure, since verifier
17200f239e4SArteen Abrishami   // guarantees from TOSA are not in place (and otherwise, if used elsewhere,
17300f239e4SArteen Abrishami   // it should fail).
17400f239e4SArteen Abrishami   // This is a basic API and may benefit from refactor into the core MLIR APIs.
17500f239e4SArteen Abrishami   std::optional<DenseElementsAttr>
17600f239e4SArteen Abrishami   transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
17700f239e4SArteen Abrishami };
17800f239e4SArteen Abrishami 
17900f239e4SArteen Abrishami std::optional<DenseElementsAttr>
18000f239e4SArteen Abrishami TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
18100f239e4SArteen Abrishami                                               ArrayRef<int32_t> perms) {
18200f239e4SArteen Abrishami   RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
18300f239e4SArteen Abrishami   RankedTensorType newType =
18400f239e4SArteen Abrishami       RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
18500f239e4SArteen Abrishami                             oldType.getElementType());
18600f239e4SArteen Abrishami   size_t rank = oldType.getRank();
18700f239e4SArteen Abrishami 
18800f239e4SArteen Abrishami   // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
18900f239e4SArteen Abrishami   // 0. If not in place, something is very wrong.
190*a58e774fSJack Frankland   if (rank <= 0 || oldType.getNumElements() <= 0) {
19100f239e4SArteen Abrishami     signalPassFailure();
19200f239e4SArteen Abrishami     return std::nullopt;
19300f239e4SArteen Abrishami   }
19400f239e4SArteen Abrishami 
19500f239e4SArteen Abrishami   if (input.isSplat())
19600f239e4SArteen Abrishami     return input.reshape(newType);
19700f239e4SArteen Abrishami 
19800f239e4SArteen Abrishami   // The algorithm is approximately as follows:
19900f239e4SArteen Abrishami   // input: perms, input flat array, input tensor type
20000f239e4SArteen Abrishami   // (1/2) determine the strides of input/output if
20100f239e4SArteen Abrishami   // they were strided in row-major order. (3) adjust the strides for the
20200f239e4SArteen Abrishami   // input to be in the same order of indices as the output is written.
20300f239e4SArteen Abrishami   // (4) process dimension by dimension. example: perms 2, 0, 1; input
20400f239e4SArteen Abrishami   // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] =
20500f239e4SArteen Abrishami   // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust
20600f239e4SArteen Abrishami   // input strides to be as input[i + 12j + 4k] so we may process
20700f239e4SArteen Abrishami   // layer-by-layer.
20800f239e4SArteen Abrishami 
20900f239e4SArteen Abrishami   // Step 1/2: Strides for input. We ignore output since row-major and can just
21000f239e4SArteen Abrishami   // push_back.
21100f239e4SArteen Abrishami 
21200f239e4SArteen Abrishami   SmallVector<int64_t> originalInputStrides(rank);
21300f239e4SArteen Abrishami   originalInputStrides[rank - 1] = 1;
21400f239e4SArteen Abrishami   // index with int64_t to avoid overflow
21500f239e4SArteen Abrishami   for (int64_t i = rank - 2; i >= 0; i--)
21600f239e4SArteen Abrishami     originalInputStrides[i] =
21700f239e4SArteen Abrishami         originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
21800f239e4SArteen Abrishami 
21900f239e4SArteen Abrishami   // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
22000f239e4SArteen Abrishami   // output which is done in row-major order.
22100f239e4SArteen Abrishami 
22200f239e4SArteen Abrishami   SmallVector<int64_t> newInputStrides;
22300f239e4SArteen Abrishami   newInputStrides.reserve(rank);
22400f239e4SArteen Abrishami   for (int32_t v : perms)
22500f239e4SArteen Abrishami     newInputStrides.push_back(originalInputStrides[v]);
22600f239e4SArteen Abrishami 
22700f239e4SArteen Abrishami   // Step 4: Write out the transposed "flat array" dimension by dimension.
22800f239e4SArteen Abrishami 
22900f239e4SArteen Abrishami   auto inputArray = input.getValues<Attribute>();
23000f239e4SArteen Abrishami   SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
23100f239e4SArteen Abrishami   for (size_t i = 0; i < rank; i++)
23200f239e4SArteen Abrishami     boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
23300f239e4SArteen Abrishami 
23400f239e4SArteen Abrishami   SmallVector<Attribute> resultArray;
23500f239e4SArteen Abrishami   resultArray.reserve(inputArray.size());
23600f239e4SArteen Abrishami 
23700f239e4SArteen Abrishami   std::function<void(int64_t,
23800f239e4SArteen Abrishami                      SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
23900f239e4SArteen Abrishami       processTransposeDim = [&](auto accumulatedIndex, auto it) {
24000f239e4SArteen Abrishami         if (it == boundsAndStrides.end()) {
24100f239e4SArteen Abrishami           resultArray.push_back(inputArray[accumulatedIndex]);
24200f239e4SArteen Abrishami           return;
24300f239e4SArteen Abrishami         }
24400f239e4SArteen Abrishami 
24500f239e4SArteen Abrishami         for (int64_t i = 0; i < it->first; i++) {
24600f239e4SArteen Abrishami           int64_t j = accumulatedIndex + i * it->second;
24700f239e4SArteen Abrishami           processTransposeDim(j, it + 1);
24800f239e4SArteen Abrishami         }
24900f239e4SArteen Abrishami       };
25000f239e4SArteen Abrishami 
25100f239e4SArteen Abrishami   processTransposeDim(0, boundsAndStrides.begin());
25200f239e4SArteen Abrishami 
25300f239e4SArteen Abrishami   return DenseElementsAttr::get(newType, resultArray);
25400f239e4SArteen Abrishami }
25500f239e4SArteen Abrishami 
25600f239e4SArteen Abrishami // The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
25700f239e4SArteen Abrishami // as the sources of the data dependencies, and TosaElementWiseOperator
25800f239e4SArteen Abrishami // after that, if the function returns true.
25900f239e4SArteen Abrishami bool TosaReduceTransposes::collectFanIn(Operation *op,
26000f239e4SArteen Abrishami                                         SetVector<Operation *> &collected) {
26100f239e4SArteen Abrishami   // Can occur if defined through the parameter to a func.func.
26200f239e4SArteen Abrishami   if (!op)
26300f239e4SArteen Abrishami     return false;
26400f239e4SArteen Abrishami 
26500f239e4SArteen Abrishami   if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
26600f239e4SArteen Abrishami     return false;
26700f239e4SArteen Abrishami 
26800f239e4SArteen Abrishami   // Prevent extra work if already seen.
26900f239e4SArteen Abrishami   if (collected.contains(op))
27000f239e4SArteen Abrishami     return true;
27100f239e4SArteen Abrishami 
27200f239e4SArteen Abrishami   // Throw it out so later don't have to deal with this.
27300f239e4SArteen Abrishami   if (op->getNumResults() != 1 ||
27400f239e4SArteen Abrishami       !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
27500f239e4SArteen Abrishami     return false;
27600f239e4SArteen Abrishami 
27700f239e4SArteen Abrishami   // We don't wish to traverse up a ReshapeOp, since generally we can't
27800f239e4SArteen Abrishami   // propagate a TransposeOp through it.  TransposeOp, ReshapeOp, ConstOp
27900f239e4SArteen Abrishami   // will have no in-edges in the data dependency graph we construct for
28000f239e4SArteen Abrishami   // the downstream TransposeOp.
28100f239e4SArteen Abrishami   if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
28200f239e4SArteen Abrishami       !llvm::isa<tosa::ConstOp>(op)) {
28300f239e4SArteen Abrishami 
28400f239e4SArteen Abrishami     if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
28500f239e4SArteen Abrishami       return false;
28600f239e4SArteen Abrishami 
28700f239e4SArteen Abrishami     for (Value operand : op->getOperands())
28800f239e4SArteen Abrishami       // If this is a problem in future, think about alternatives to recursion.
28900f239e4SArteen Abrishami       if (!collectFanIn(operand.getDefiningOp(), collected))
29000f239e4SArteen Abrishami         return false;
29100f239e4SArteen Abrishami   }
29200f239e4SArteen Abrishami 
29300f239e4SArteen Abrishami   // Insert in topological order.
29400f239e4SArteen Abrishami   collected.insert(op);
29500f239e4SArteen Abrishami 
29600f239e4SArteen Abrishami   return true;
29700f239e4SArteen Abrishami }
29800f239e4SArteen Abrishami 
29900f239e4SArteen Abrishami // Assuming that due to the verification of TransposeOp perms arrays are
30000f239e4SArteen Abrishami // permutations of 0 - perms.size() - 1.
30100f239e4SArteen Abrishami bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
30200f239e4SArteen Abrishami                                                    ArrayRef<int32_t> perms2) {
30300f239e4SArteen Abrishami   if (perms1.size() != perms2.size())
30400f239e4SArteen Abrishami     return false;
30500f239e4SArteen Abrishami   int32_t n = perms1.size();
30600f239e4SArteen Abrishami   for (int32_t i = 0; i < n; i++)
30700f239e4SArteen Abrishami     if (perms2[perms1[i]] != i)
30800f239e4SArteen Abrishami       return false;
30900f239e4SArteen Abrishami   return true;
31000f239e4SArteen Abrishami }
31100f239e4SArteen Abrishami 
31200f239e4SArteen Abrishami // Primary overload for those with TosaElementwiseOperator trait.
31300f239e4SArteen Abrishami // The other ones handle the case of the operations that occur at the
31400f239e4SArteen Abrishami // roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
31500f239e4SArteen Abrishami std::optional<Value> TosaReduceTransposes::buildMappedToValue(
31600f239e4SArteen Abrishami     Operation *op, const DenseMap<Value, Value> &valuesMap,
31700f239e4SArteen Abrishami     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
31800f239e4SArteen Abrishami   if (op->getNumResults() != 1 ||
31900f239e4SArteen Abrishami       !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
32000f239e4SArteen Abrishami     return std::nullopt;
32100f239e4SArteen Abrishami 
32200f239e4SArteen Abrishami   auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
32300f239e4SArteen Abrishami   SmallVector<Value, 3> operands;
32400f239e4SArteen Abrishami   for (Value v : op->getOperands()) {
32500f239e4SArteen Abrishami     if (valuesMap.contains(v)) {
32600f239e4SArteen Abrishami       operands.push_back(valuesMap.at(v));
32700f239e4SArteen Abrishami     } else {
32800f239e4SArteen Abrishami       return std::nullopt;
32900f239e4SArteen Abrishami     }
33000f239e4SArteen Abrishami   }
33100f239e4SArteen Abrishami 
33200f239e4SArteen Abrishami   // Conceptually, we propagate the hoisted TransposeOp through
33300f239e4SArteen Abrishami   // these interveaning operations. For example,
33400f239e4SArteen Abrishami 
33500f239e4SArteen Abrishami   // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
33600f239e4SArteen Abrishami   // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) ->
33700f239e4SArteen Abrishami   // tensor<3x2xi32>
33800f239e4SArteen Abrishami 
33900f239e4SArteen Abrishami   // becomes:
34000f239e4SArteen Abrishami   // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) ->
34100f239e4SArteen Abrishami   // tensor<3x2xi32>
34200f239e4SArteen Abrishami   // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>)
34300f239e4SArteen Abrishami 
34400f239e4SArteen Abrishami   // We construct this new tosa.clamp here, but it doesn't
34500f239e4SArteen Abrishami   // turn "live" until the transpose being hoisted through this chain
34600f239e4SArteen Abrishami   // is replaced with the proper value from the new chain.
34700f239e4SArteen Abrishami 
34800f239e4SArteen Abrishami   return rewriter
34900f239e4SArteen Abrishami       .create(op->getLoc(), op->getName().getIdentifier(), operands,
35000f239e4SArteen Abrishami               RankedTensorType::get(
35100f239e4SArteen Abrishami                   applyTOSAPermutation(resultType.getShape(), hoistedPerms),
35200f239e4SArteen Abrishami                   resultType.getElementType()),
35300f239e4SArteen Abrishami               op->getAttrs())
35400f239e4SArteen Abrishami       ->getResult(0);
35500f239e4SArteen Abrishami }
35600f239e4SArteen Abrishami 
35700f239e4SArteen Abrishami std::optional<Value> TosaReduceTransposes::buildMappedToValue(
35800f239e4SArteen Abrishami     TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
35900f239e4SArteen Abrishami     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
36000f239e4SArteen Abrishami   SmallVector<int32_t> perms;
36100f239e4SArteen Abrishami   if (failed(transposeOp.getConstantPerms(perms)) ||
36200f239e4SArteen Abrishami       !areInvolutionTransposes(hoistedPerms, perms))
36300f239e4SArteen Abrishami     return std::nullopt;
36400f239e4SArteen Abrishami   return transposeOp.getInput1();
36500f239e4SArteen Abrishami }
36600f239e4SArteen Abrishami 
36700f239e4SArteen Abrishami std::optional<Value> TosaReduceTransposes::buildMappedToValue(
36800f239e4SArteen Abrishami     ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
36900f239e4SArteen Abrishami     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
37000f239e4SArteen Abrishami   auto reshapeOutput = reshapeOp.getOutput();
37100f239e4SArteen Abrishami   auto reshapeInputType =
37200f239e4SArteen Abrishami       llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
37300f239e4SArteen Abrishami   auto reshapeInputShape = reshapeInputType.getShape();
37400f239e4SArteen Abrishami   // want reshape N -> 1x1x...x1xNx1x...x1x1
37500f239e4SArteen Abrishami   if (!reshapeInputType || reshapeInputShape.size() != 1)
37600f239e4SArteen Abrishami     return std::nullopt;
37700f239e4SArteen Abrishami   auto reshapeOutputType =
37800f239e4SArteen Abrishami       llvm::cast<RankedTensorType>(reshapeOutput.getType());
37900f239e4SArteen Abrishami 
38000f239e4SArteen Abrishami   // Instead of inserting a TransposeOp here, we check if we can fold it into
38100f239e4SArteen Abrishami   // the ReshapeOp. There is more complex cases where this is possible, and
38200f239e4SArteen Abrishami   // this check can be extended.
38300f239e4SArteen Abrishami 
38400f239e4SArteen Abrishami   // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1
38500f239e4SArteen Abrishami   auto shape = reshapeOutputType.getShape();
38600f239e4SArteen Abrishami   size_t ones = llvm::count(shape, 1);
38700f239e4SArteen Abrishami   // N == 1 and N != 1
38800f239e4SArteen Abrishami   if (ones != shape.size() - 1 &&
38900f239e4SArteen Abrishami       !(ones == shape.size() && reshapeInputShape[0] == 1))
39000f239e4SArteen Abrishami     return std::nullopt;
39100f239e4SArteen Abrishami 
39200f239e4SArteen Abrishami   // Do not insert a TransposeOp, instead we fold the reshape and its attribute.
39300f239e4SArteen Abrishami   auto foldedReshape = rewriter.create<ReshapeOp>(
39400f239e4SArteen Abrishami       reshapeOp.getLoc(),
39500f239e4SArteen Abrishami       RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
39600f239e4SArteen Abrishami                             reshapeOutputType.getElementType()),
39700f239e4SArteen Abrishami       reshapeOp.getInput1(),
39800f239e4SArteen Abrishami       rewriter.getDenseI64ArrayAttr(
39900f239e4SArteen Abrishami           applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
40000f239e4SArteen Abrishami   return foldedReshape->getResult(0);
40100f239e4SArteen Abrishami }
40200f239e4SArteen Abrishami 
40300f239e4SArteen Abrishami std::optional<Value> TosaReduceTransposes::buildMappedToValue(
40400f239e4SArteen Abrishami     ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
40500f239e4SArteen Abrishami     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
40600f239e4SArteen Abrishami   auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
40700f239e4SArteen Abrishami   if (!denseAttr)
40800f239e4SArteen Abrishami     return std::nullopt;
40900f239e4SArteen Abrishami   auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
41000f239e4SArteen Abrishami   if (!maybeNewDenseAttr.has_value())
41100f239e4SArteen Abrishami     return std::nullopt;
41200f239e4SArteen Abrishami   auto newDenseAttr = maybeNewDenseAttr.value();
41300f239e4SArteen Abrishami   auto newConstOp = rewriter.create<ConstOp>(
41400f239e4SArteen Abrishami       constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
41500f239e4SArteen Abrishami   return newConstOp->getResult(0);
41600f239e4SArteen Abrishami }
41700f239e4SArteen Abrishami 
41800f239e4SArteen Abrishami bool TosaReduceTransposes::convertDependentOps(
41900f239e4SArteen Abrishami     SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
42000f239e4SArteen Abrishami     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
42100f239e4SArteen Abrishami 
42200f239e4SArteen Abrishami   for (Operation *op : dependentOps) {
42300f239e4SArteen Abrishami     if (!op || op->getNumResults() != 1)
42400f239e4SArteen Abrishami       return false;
42500f239e4SArteen Abrishami 
42600f239e4SArteen Abrishami     Value priorValue = op->getResult(0);
42700f239e4SArteen Abrishami 
42800f239e4SArteen Abrishami     // It's possible on a prior transposeOp we had the same dependency and
42900f239e4SArteen Abrishami     // already resolved it.
43000f239e4SArteen Abrishami     if (valuesMap.contains(priorValue))
43100f239e4SArteen Abrishami       continue;
43200f239e4SArteen Abrishami 
43300f239e4SArteen Abrishami     // Keep converted ops close to the original.
43400f239e4SArteen Abrishami     rewriter.setInsertionPointAfter(op);
43500f239e4SArteen Abrishami 
43600f239e4SArteen Abrishami     std::optional<Value> maybeValue =
43700f239e4SArteen Abrishami         llvm::TypeSwitch<Operation *, std::optional<Value>>(op)
43800f239e4SArteen Abrishami             .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto transposeOp) {
43900f239e4SArteen Abrishami               return buildMappedToValue(transposeOp, valuesMap, rewriter,
44000f239e4SArteen Abrishami                                         hoistedPerms);
44100f239e4SArteen Abrishami             })
44200f239e4SArteen Abrishami             .Default([&](Operation *op) {
44300f239e4SArteen Abrishami               return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
44400f239e4SArteen Abrishami             });
44500f239e4SArteen Abrishami 
44600f239e4SArteen Abrishami     if (!maybeValue.has_value())
44700f239e4SArteen Abrishami       return false;
44800f239e4SArteen Abrishami 
44900f239e4SArteen Abrishami     valuesMap[priorValue] = maybeValue.value();
45000f239e4SArteen Abrishami   }
45100f239e4SArteen Abrishami 
45200f239e4SArteen Abrishami   return true;
45300f239e4SArteen Abrishami }
45400f239e4SArteen Abrishami 
45500f239e4SArteen Abrishami bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
45600f239e4SArteen Abrishami     Operation *user, std::set<TransposeOp> &validTransposes,
45700f239e4SArteen Abrishami     std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
45800f239e4SArteen Abrishami         &transposeInfo) {
45900f239e4SArteen Abrishami   return llvm::none_of(
46000f239e4SArteen Abrishami       transposeInfo,
46100f239e4SArteen Abrishami       [&validTransposes,
46200f239e4SArteen Abrishami        user](const std::pair<TransposeOp, SetVector<Operation *>> &info) {
46300f239e4SArteen Abrishami         const auto &[transposeOp, dependentOps] = info;
46400f239e4SArteen Abrishami         return validTransposes.count(transposeOp) &&
46500f239e4SArteen Abrishami                dependentOps.contains(user);
46600f239e4SArteen Abrishami       });
46700f239e4SArteen Abrishami }
46800f239e4SArteen Abrishami 
46900f239e4SArteen Abrishami // Dependencies are valid for an operation if none of them occur outside
47000f239e4SArteen Abrishami // of the proper fan-in cones of the hoisted TransposeOp with the same perms
47100f239e4SArteen Abrishami // that we can replace. Described in more detail within.
47200f239e4SArteen Abrishami bool TosaReduceTransposes::dependenciesAreValid(
47300f239e4SArteen Abrishami     ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
47400f239e4SArteen Abrishami     std::set<TransposeOp> &validTransposes,
47500f239e4SArteen Abrishami     std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
47600f239e4SArteen Abrishami         &transposeInfo) {
47700f239e4SArteen Abrishami   for (Operation *op : dependentOps) {
47800f239e4SArteen Abrishami 
47900f239e4SArteen Abrishami     // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
48000f239e4SArteen Abrishami     // This can be changed later if we find the memory impact is too high.
48100f239e4SArteen Abrishami     if (llvm::isa<ConstOp>(op))
48200f239e4SArteen Abrishami       continue;
48300f239e4SArteen Abrishami 
48400f239e4SArteen Abrishami     for (OpOperand &use : op->getUses()) {
48500f239e4SArteen Abrishami       // Want the uses to be (1) contained in the dependentOps of other
48600f239e4SArteen Abrishami       // validTransposes, or (2) to be directly used in a TransposeOp with the
48700f239e4SArteen Abrishami       // same perms. For (2) it means the fan-in is a subset of our
48800f239e4SArteen Abrishami       // dependentOps, so it is also a validTranspose that will eventually be
48900f239e4SArteen Abrishami       // replaced.
49000f239e4SArteen Abrishami       Operation *user = use.getOwner();
49100f239e4SArteen Abrishami       if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
49200f239e4SArteen Abrishami         SmallVector<int32_t> otherPerms;
49300f239e4SArteen Abrishami 
49400f239e4SArteen Abrishami         // Can later think about cases where transpose -> transpose
49500f239e4SArteen Abrishami         // or reshape -> transpose, where the transposes are not necessarily
49600f239e4SArteen Abrishami         // the same perms as the hoisted, if implementing a more general
49700f239e4SArteen Abrishami         // transform. These could be permitted.
49800f239e4SArteen Abrishami         if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
49900f239e4SArteen Abrishami             !llvm::equal(perms, otherPerms))
50000f239e4SArteen Abrishami           return false;
50100f239e4SArteen Abrishami       } else if (userNotContainedInValidTransposeDependencies(
50200f239e4SArteen Abrishami                      user, validTransposes, transposeInfo)) {
50300f239e4SArteen Abrishami         return false;
50400f239e4SArteen Abrishami       }
50500f239e4SArteen Abrishami     }
50600f239e4SArteen Abrishami   }
50700f239e4SArteen Abrishami 
50800f239e4SArteen Abrishami   return true;
50900f239e4SArteen Abrishami }
51000f239e4SArteen Abrishami 
51100f239e4SArteen Abrishami // Getting the set of TransposeOp that we can replace without causing
51200f239e4SArteen Abrishami // the old fan-in cones of any TransposeOp to remain "live", i.e, -- not being
51300f239e4SArteen Abrishami // dead code. This is done by iterating the set until convergence, since
51400f239e4SArteen Abrishami // if you are used outside your own fan-in cone, it's possible to be used
51500f239e4SArteen Abrishami // in another fan-in cone of a TransposeOp that is being replaced -- unless
51600f239e4SArteen Abrishami // we find that that one has a usage outside of it too.
51700f239e4SArteen Abrishami std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
51800f239e4SArteen Abrishami     ArrayRef<int32_t> perms,
51900f239e4SArteen Abrishami     std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
52000f239e4SArteen Abrishami         &transposeInfo) {
52100f239e4SArteen Abrishami   // Initially, we assume they are all good to replace,
52200f239e4SArteen Abrishami   // and we whittle them down based on our criteria.
52300f239e4SArteen Abrishami   std::set<TransposeOp> ableToReplace;
52400f239e4SArteen Abrishami   for (const auto &[transposeOp, _] : transposeInfo)
52500f239e4SArteen Abrishami     ableToReplace.insert(transposeOp);
52600f239e4SArteen Abrishami 
52700f239e4SArteen Abrishami   bool gotRid;
52800f239e4SArteen Abrishami   do {
52900f239e4SArteen Abrishami     gotRid = false;
53000f239e4SArteen Abrishami     for (const auto &[transposeOp, dependentOps] : transposeInfo) {
53100f239e4SArteen Abrishami       // We don't care about it. Already invalidated.
53200f239e4SArteen Abrishami       if (!ableToReplace.count(transposeOp))
53300f239e4SArteen Abrishami         continue;
53400f239e4SArteen Abrishami 
53500f239e4SArteen Abrishami       // Check for validity.
53600f239e4SArteen Abrishami       if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
53700f239e4SArteen Abrishami                                 transposeInfo)) {
53800f239e4SArteen Abrishami         ableToReplace.erase(transposeOp);
53900f239e4SArteen Abrishami         gotRid = true;
54000f239e4SArteen Abrishami         break;
54100f239e4SArteen Abrishami       }
54200f239e4SArteen Abrishami     }
54300f239e4SArteen Abrishami 
54400f239e4SArteen Abrishami   } while (gotRid);
54500f239e4SArteen Abrishami 
54600f239e4SArteen Abrishami   return ableToReplace;
54700f239e4SArteen Abrishami }
54800f239e4SArteen Abrishami 
54900f239e4SArteen Abrishami void TosaReduceTransposes::runOnOperation() {
55000f239e4SArteen Abrishami   // We want to operate only within a single block.
55100f239e4SArteen Abrishami   if (!getOperation().getRegion().hasOneBlock())
55200f239e4SArteen Abrishami     return;
55300f239e4SArteen Abrishami 
55400f239e4SArteen Abrishami   IRRewriter rewriter(&getContext());
55500f239e4SArteen Abrishami   // For each perms, maintain a mapping for converted ops, avoid duplication.
55600f239e4SArteen Abrishami   DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
55700f239e4SArteen Abrishami   // For each perms, we keep track of which TransposeOp are eligible
55800f239e4SArteen Abrishami   // for replacement alongside their dependentOps.
55900f239e4SArteen Abrishami   DenseMap<ArrayRef<int32_t>,
56000f239e4SArteen Abrishami            std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
56100f239e4SArteen Abrishami       permsToTransposeInfo;
56200f239e4SArteen Abrishami 
56300f239e4SArteen Abrishami   // Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef.
56400f239e4SArteen Abrishami   // Use SmallVector for perms (common-case is <= 4) but std::vector otherwise
56500f239e4SArteen Abrishami   // since no guarantee of smallness.
56600f239e4SArteen Abrishami   std::vector<SmallVector<int32_t>> collectedPerms;
56700f239e4SArteen Abrishami 
56800f239e4SArteen Abrishami   // This keeps track of the order across all eligible-for-replacement
56900f239e4SArteen Abrishami   // TransposeOp and their perms, a necessity for the final replacements.
57000f239e4SArteen Abrishami   std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;
57100f239e4SArteen Abrishami 
57200f239e4SArteen Abrishami   // We want to reserve the space up front, since SmallVector stores some data
57300f239e4SArteen Abrishami   // internally and the ArrayRef can reference that, which we don't want to get
57400f239e4SArteen Abrishami   // invalidated.
57500f239e4SArteen Abrishami   size_t expectedMaxPerms = 0;
57600f239e4SArteen Abrishami   getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
57700f239e4SArteen Abrishami   collectedPerms.reserve(expectedMaxPerms);
57800f239e4SArteen Abrishami 
57900f239e4SArteen Abrishami   getOperation().walk([&](TransposeOp transposeOp) {
58000f239e4SArteen Abrishami     SetVector<Operation *> dependentOps;
58100f239e4SArteen Abrishami     collectedPerms.emplace_back();
58200f239e4SArteen Abrishami     SmallVector<int32_t> &perms = collectedPerms.back();
58300f239e4SArteen Abrishami 
58400f239e4SArteen Abrishami     // Dynamic shapes are OK, but the incompatible ones will be rejected later.
58500f239e4SArteen Abrishami     auto input = transposeOp.getInput1();
58600f239e4SArteen Abrishami     auto output = transposeOp.getOutput();
58700f239e4SArteen Abrishami 
58800f239e4SArteen Abrishami     // However, we don't support unranked tensors.
58900f239e4SArteen Abrishami     if (!llvm::isa<RankedTensorType>(input.getType()) ||
59000f239e4SArteen Abrishami         !llvm::isa<RankedTensorType>(output.getType()))
59100f239e4SArteen Abrishami       return;
59200f239e4SArteen Abrishami 
59300f239e4SArteen Abrishami     // No transformation when transpose permutation non-constant.
59400f239e4SArteen Abrishami     if (failed(transposeOp.getConstantPerms(perms)))
59500f239e4SArteen Abrishami       return;
59600f239e4SArteen Abrishami 
59700f239e4SArteen Abrishami     // We let --canonicalize deal with identity transpose.
59800f239e4SArteen Abrishami     if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
59900f239e4SArteen Abrishami       return;
60000f239e4SArteen Abrishami 
60100f239e4SArteen Abrishami     // Can fail if some set of basic invariants is not met that we want to
60200f239e4SArteen Abrishami     // perform our conversions.
60300f239e4SArteen Abrishami     if (!collectFanIn(input.getDefiningOp(), dependentOps))
60400f239e4SArteen Abrishami       return;
60500f239e4SArteen Abrishami 
60600f239e4SArteen Abrishami     // Want to associate valuesMap for already converted of the same perms,
60700f239e4SArteen Abrishami     // since it's possible multiple hoisted transposes w/ different perms
60800f239e4SArteen Abrishami     // converge on an op, which would result in different transformations.
60900f239e4SArteen Abrishami     DenseMap<Value, Value> &valuesMap = permsToValues[perms];
61000f239e4SArteen Abrishami 
61100f239e4SArteen Abrishami     // Attempt to perform the conversions and placements into IR
61200f239e4SArteen Abrishami     // without turning inserted code "live". Also fills out valuesMap.
61300f239e4SArteen Abrishami     // Fails if there is an intermediary we do not support.
61400f239e4SArteen Abrishami     if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
61500f239e4SArteen Abrishami       // Some additional operations may have been inserted, but will be
61600f239e4SArteen Abrishami       // removed by dead code elimination.
61700f239e4SArteen Abrishami       return;
61800f239e4SArteen Abrishami 
61900f239e4SArteen Abrishami     // This should not happen. If it does -- it's unexpected,
62000f239e4SArteen Abrishami     // so we fail the pass.
62100f239e4SArteen Abrishami     if (!valuesMap.contains(input))
62200f239e4SArteen Abrishami       return signalPassFailure();
62300f239e4SArteen Abrishami 
62400f239e4SArteen Abrishami     // It's possible the types are not compatible (because of dynamic shapes),
62500f239e4SArteen Abrishami     // and in these cases, want to resolve dynamic shapes before running the
62600f239e4SArteen Abrishami     // pass.
62700f239e4SArteen Abrishami     if (output.getType() != valuesMap.at(input).getType())
62800f239e4SArteen Abrishami       return;
62900f239e4SArteen Abrishami 
63000f239e4SArteen Abrishami     auto &transposeInfo = permsToTransposeInfo[perms];
63100f239e4SArteen Abrishami 
63200f239e4SArteen Abrishami     // In general, we might also want to introduce "newDependentOps"
63300f239e4SArteen Abrishami     // if there are new usages that don't fall inside the original fan-ins
63400f239e4SArteen Abrishami     // (like the TransposeOp we insert for ReshapeOp),
63500f239e4SArteen Abrishami     // but in this case, that is specialized enough and overlaps
63600f239e4SArteen Abrishami     // with another direct-use TransposeOp case we need to cover anyway.
63700f239e4SArteen Abrishami     transposeInfo.push_back({transposeOp, dependentOps});
63800f239e4SArteen Abrishami 
63900f239e4SArteen Abrishami     // This is for the final replacement across all transposes.
64000f239e4SArteen Abrishami     totalTransposeOrder.push({transposeOp, perms});
64100f239e4SArteen Abrishami   });
64200f239e4SArteen Abrishami 
64300f239e4SArteen Abrishami   // We want to do a full fan-in analysis on a perms-level,
64400f239e4SArteen Abrishami   // since if we do it on a multi-perms level, and they share (due to a shared
64500f239e4SArteen Abrishami   // dependency on a Reshape) then we would also get duplicate ops.
64600f239e4SArteen Abrishami   // Const is special cased.
64700f239e4SArteen Abrishami   std::set<TransposeOp> ableToReplace;
64800f239e4SArteen Abrishami   for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
64900f239e4SArteen Abrishami     // Gives us back replacements that would never result in any duplicate
65000f239e4SArteen Abrishami     // operations being inserted by us in the IR (i.e, our goal is only to
65100f239e4SArteen Abrishami     // remove transposes, and not create a "new chain" to do so, but replace
65200f239e4SArteen Abrishami     // the existing chains).
65300f239e4SArteen Abrishami     // Ideally, --canonicalize is run before this pass, since it helps this
65400f239e4SArteen Abrishami     // analysis by removing dead code to allow more potentially acceptable
65500f239e4SArteen Abrishami     // transformations.
65600f239e4SArteen Abrishami     auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
65700f239e4SArteen Abrishami     ableToReplace.insert(goodReplacementsForPerms.begin(),
65800f239e4SArteen Abrishami                          goodReplacementsForPerms.end());
65900f239e4SArteen Abrishami   }
66000f239e4SArteen Abrishami 
66100f239e4SArteen Abrishami   // We want to do replacement across all transposes
66200f239e4SArteen Abrishami   // in reverse order, due to invalidation of valuesMap mappings
66300f239e4SArteen Abrishami   // if we did it otherwise.
66400f239e4SArteen Abrishami   while (!totalTransposeOrder.empty()) {
66500f239e4SArteen Abrishami     auto [transposeOp, perms] = totalTransposeOrder.top();
66600f239e4SArteen Abrishami     totalTransposeOrder.pop();
66700f239e4SArteen Abrishami 
66800f239e4SArteen Abrishami     if (ableToReplace.count(transposeOp) == 0)
66900f239e4SArteen Abrishami       continue;
67000f239e4SArteen Abrishami 
67100f239e4SArteen Abrishami     auto &valuesMap = permsToValues[perms];
67200f239e4SArteen Abrishami     auto input = transposeOp.getInput1();
67300f239e4SArteen Abrishami 
67400f239e4SArteen Abrishami     // The purpose of this reverse iteration
67500f239e4SArteen Abrishami     // is to avoid valuesMap invalidation. If it happens,
67600f239e4SArteen Abrishami     // something is wrong.
67700f239e4SArteen Abrishami     if (!valuesMap.contains(input))
67800f239e4SArteen Abrishami       return signalPassFailure();
67900f239e4SArteen Abrishami 
68000f239e4SArteen Abrishami     rewriter.replaceOp(transposeOp, valuesMap.at(input));
68100f239e4SArteen Abrishami   }
68200f239e4SArteen Abrishami 
68300f239e4SArteen Abrishami   // We can remove all dead code by going in reverse.
68400f239e4SArteen Abrishami   // This is because we would remove usages before we
68500f239e4SArteen Abrishami   // see the users.
68600f239e4SArteen Abrishami   getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
68700f239e4SArteen Abrishami       [&](Operation *op) {
68800f239e4SArteen Abrishami         if (isOpTriviallyDead(op))
68900f239e4SArteen Abrishami           rewriter.eraseOp(op);
69000f239e4SArteen Abrishami       });
69100f239e4SArteen Abrishami }
69200f239e4SArteen Abrishami 
69300f239e4SArteen Abrishami } // namespace
694