1 //===- TosaReduceTransposes.cpp -------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 // ---------- 10 // Motivation: 11 // ---------- 12 13 // Some legalization pathways introduce redundant tosa.TRANSPOSE 14 // operations that result in avoidable data movement. For example, 15 // PyTorch -> TOSA contains a lot of unnecessary transposes due 16 // to conversions between NCHW and NHWC. 17 18 // We wish to remove all the ones that we can, since in general 19 // it is possible to remove the overwhelming majority. 20 21 // ------------------- 22 // High-Level Overview: 23 // ------------------- 24 25 // The pass works through the transpose operators in the program. It begins at 26 // some transpose operator with an associated permutations tensor. It traverses 27 // upwards through the dependencies of this transpose and verifies that we 28 // encounter only operators with the TosaElementwiseOperator trait and terminate 29 // in either constants, reshapes, or transposes. 30 31 // We then evaluate whether there are any additional restrictions (the 32 // transposes it terminates in must invert the one we began at, and the reshapes 33 // must be ones in which we can fold the transpose into), and then we hoist the 34 // transpose through the intervening operators, folding it at the constants, 35 // reshapes, and transposes. 36 37 // Finally, we ensure that we do not need both the transposed form (the form 38 // that had the transpose hoisted through it) and the untransposed form (which 39 // it was prior), by analyzing the usages of those dependent operators of a 40 // given transpose we are attempting to hoist and replace. 
// If they are such that it would require both forms to be necessary, then we do
// not replace the hoisted transpose, causing the new chain to be dead.
// Otherwise, we do and the old chain (untransposed form) becomes dead. Only one
// chain will ever then be live, resulting in no duplication.

// We then perform a simple one-pass DCE, so no canonicalization is necessary.

// -----------
// Future Work:
// -----------

// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
// hoisted transposes with different permutation tensors.

// (2) Expand the class of foldable upstream ReshapeOp we permit beyond
// N -> 1x1x...x1xNx1x...x1x1.

// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
// those that form the identity.

// (4) Add support for more instructions besides TosaElementwiseOperator as
// the intervening ones (for example, the reduce_* operators).

// (5) Support hoisting transposes up to an input parameter.

//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/TypeSwitch.h"
#include <memory>
#include <set>
#include <stack>

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

//===----------------------------------------------------------------------===//
// TOSA Reduce Transposes Pass.
//===----------------------------------------------------------------------===//

namespace {

// Pass that hoists tosa.transpose ops upward through chains of elementwise
// ops, folding them into the constants, reshapes, and inverse transposes at
// the roots of each chain, then DCEs whichever chain (old or new) is dead.
struct TosaReduceTransposes final
    : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
  void runOnOperation() override;

private:
  // This will collect all the data dependencies for the given Operation
  // up to and including ConstOp, ReshapeOp, and TransposeOp.
  // Returns false if the fan-in contains anything we cannot handle
  // (non-TOSA ops, multi-result ops, unranked results, or unsupported
  // intermediaries). On success, `collected` holds the fan-in in
  // topological order.
  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);

  // Builds (but does not make "live") the transposed counterpart of every
  // op in `dependentOps`, recording original-value -> new-value entries in
  // `valuesMap`. Returns false if any op cannot be converted.
  bool convertDependentOps(SetVector<Operation *> &dependentOps,
                           DenseMap<Value, Value> &valuesMap,
                           IRRewriter &rewriter,
                           ArrayRef<int32_t> hoistedPerms);

  // Checks if the two permutations, when applied consecutively, result
  // in the identity.
  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
                               ArrayRef<int32_t> perms2);

  // This is meant to apply to operations with the TosaElementwiseOperator
  // trait.
  std::optional<Value>
  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // This updates valuesMap when we encounter another TransposeOp as a
  // dependency of the hoisted one. For example:
  //   %0 = tosa.transpose %arg0 <- applies to this
  //   %1 = tosa.transpose %0    <- when tracking back from this
  std::optional<Value>
  buildMappedToValue(TransposeOp transposeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks if ReshapeOp can have hoisted TransposeOp folded into it. If so,
  // it creates new ReshapeOp with that fold.
  std::optional<Value>
  buildMappedToValue(ReshapeOp reshapeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // We may have something like:
  // %0 = tosa.const
  // %1 = tosa.transpose
  // %2 = tosa.add %0, %1
  // %3 = tosa.transpose %2
  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
  // in MobilenetV3.
  std::optional<Value>
  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Checks which TransposeOp we should "replace", turning their converted
  // chains of ops, through which they were propagated, "live", and the old code
  // "dead." Attempts to avoid doing so when doing so would result in the old
  // code staying "live," resulting in duplication.
  std::set<TransposeOp> getGoodReplacements(
      ArrayRef<int32_t> perms,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for dependenciesAreValid.
  bool userNotContainedInValidTransposeDependencies(
      Operation *user, std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper function for getGoodReplacements to check if some TransposeOp's
  // dependencies are OK.
  bool dependenciesAreValid(
      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
      std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Applies perms to the DenseElementsAttr.
  // If it returns std::nullopt, it also triggers pass failure, since verifier
  // guarantees from TOSA are not in place (and otherwise, if used elsewhere,
  // it should fail).
  // This is a basic API and may benefit from refactor into the core MLIR APIs.
  std::optional<DenseElementsAttr>
  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
};

std::optional<DenseElementsAttr>
TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
                                              ArrayRef<int32_t> perms) {
  RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
  RankedTensorType newType =
      RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
                            oldType.getElementType());
  size_t rank = oldType.getRank();

  // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
  // 0. If not in place, something is very wrong.
  if (rank <= 0 || oldType.getNumElements() <= 0) {
    signalPassFailure();
    return std::nullopt;
  }

  // Splats are layout-independent: only the type needs to change.
  if (input.isSplat())
    return input.reshape(newType);

  // The algorithm is approximately as follows:
  // input: perms, input flat array, input tensor type
  // (1/2) determine the strides of input/output if
  // they were strided in row-major order. (3) adjust the strides for the
  // input to be in the same order of indices as the output is written.
  // (4) process dimension by dimension. example: perms 2, 0, 1; input
  // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] =
  // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust
  // input strides to be as input[i + 12j + 4k] so we may process
  // layer-by-layer.

  // Step 1/2: Strides for input. We ignore output since row-major and can just
  // push_back.

  SmallVector<int64_t> originalInputStrides(rank);
  originalInputStrides[rank - 1] = 1;
  // index with int64_t to avoid overflow
  for (int64_t i = rank - 2; i >= 0; i--)
    originalInputStrides[i] =
        originalInputStrides[i + 1] * oldType.getDimSize(i + 1);

  // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
  // output which is done in row-major order.

  SmallVector<int64_t> newInputStrides;
  newInputStrides.reserve(rank);
  for (int32_t v : perms)
    newInputStrides.push_back(originalInputStrides[v]);

  // Step 4: Write out the transposed "flat array" dimension by dimension.

  auto inputArray = input.getValues<Attribute>();
  SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
  for (size_t i = 0; i < rank; i++)
    boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});

  SmallVector<Attribute> resultArray;
  resultArray.reserve(inputArray.size());

  // Recursively enumerates output indices in row-major order; at the leaf,
  // `accumulatedIndex` is the flat index into the *input* array. Declared as
  // std::function since the lambda refers to itself recursively.
  std::function<void(int64_t,
                     SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
      processTransposeDim = [&](auto accumulatedIndex, auto it) {
        if (it == boundsAndStrides.end()) {
          resultArray.push_back(inputArray[accumulatedIndex]);
          return;
        }

        for (int64_t i = 0; i < it->first; i++) {
          int64_t j = accumulatedIndex + i * it->second;
          processTransposeDim(j, it + 1);
        }
      };

  processTransposeDim(0, boundsAndStrides.begin());

  return DenseElementsAttr::get(newType, resultArray);
}

// The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
// as the sources of the data dependencies, and TosaElementWiseOperator
// after that, if the function returns true.
bool TosaReduceTransposes::collectFanIn(Operation *op,
                                        SetVector<Operation *> &collected) {
  // Can occur if defined through the parameter to a func.func.
  if (!op)
    return false;

  if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
    return false;

  // Prevent extra work if already seen.
  if (collected.contains(op))
    return true;

  // Throw it out so later don't have to deal with this.
  if (op->getNumResults() != 1 ||
      !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
    return false;

  // We don't wish to traverse up a ReshapeOp, since generally we can't
  // propagate a TransposeOp through it. TransposeOp, ReshapeOp, ConstOp
  // will have no in-edges in the data dependency graph we construct for
  // the downstream TransposeOp.
  if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
      !llvm::isa<tosa::ConstOp>(op)) {

    if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
      return false;

    for (Value operand : op->getOperands())
      // If this is a problem in future, think about alternatives to recursion.
      if (!collectFanIn(operand.getDefiningOp(), collected))
        return false;
  }

  // Insert in topological order.
  collected.insert(op);

  return true;
}

// Assuming that due to the verification of TransposeOp perms arrays are
// permutations of 0 - perms.size() - 1.
bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
                                                   ArrayRef<int32_t> perms2) {
  if (perms1.size() != perms2.size())
    return false;
  int32_t n = perms1.size();
  for (int32_t i = 0; i < n; i++)
    if (perms2[perms1[i]] != i)
      return false;
  return true;
}

// Primary overload for those with TosaElementwiseOperator trait.
// The other ones handle the case of the operations that occur at the
// roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
315 std::optional<Value> TosaReduceTransposes::buildMappedToValue( 316 Operation *op, const DenseMap<Value, Value> &valuesMap, 317 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 318 if (op->getNumResults() != 1 || 319 !op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>()) 320 return std::nullopt; 321 322 auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType()); 323 SmallVector<Value, 3> operands; 324 for (Value v : op->getOperands()) { 325 if (valuesMap.contains(v)) { 326 operands.push_back(valuesMap.at(v)); 327 } else { 328 return std::nullopt; 329 } 330 } 331 332 // Conceptually, we propagate the hoisted TransposeOp through 333 // these interveaning operations. For example, 334 335 // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32> 336 // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) -> 337 // tensor<3x2xi32> 338 339 // becomes: 340 // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) -> 341 // tensor<3x2xi32> 342 // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>) 343 344 // We construct this new tosa.clamp here, but it doesn't 345 // turn "live" until the transpose being hoisted through this chain 346 // is replaced with the proper value from the new chain. 
347 348 return rewriter 349 .create(op->getLoc(), op->getName().getIdentifier(), operands, 350 RankedTensorType::get( 351 applyTOSAPermutation(resultType.getShape(), hoistedPerms), 352 resultType.getElementType()), 353 op->getAttrs()) 354 ->getResult(0); 355 } 356 357 std::optional<Value> TosaReduceTransposes::buildMappedToValue( 358 TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap, 359 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 360 SmallVector<int32_t> perms; 361 if (failed(transposeOp.getConstantPerms(perms)) || 362 !areInvolutionTransposes(hoistedPerms, perms)) 363 return std::nullopt; 364 return transposeOp.getInput1(); 365 } 366 367 std::optional<Value> TosaReduceTransposes::buildMappedToValue( 368 ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap, 369 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 370 auto reshapeOutput = reshapeOp.getOutput(); 371 auto reshapeInputType = 372 llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType()); 373 auto reshapeInputShape = reshapeInputType.getShape(); 374 // want reshape N -> 1x1x...x1xNx1x...x1x1 375 if (!reshapeInputType || reshapeInputShape.size() != 1) 376 return std::nullopt; 377 auto reshapeOutputType = 378 llvm::cast<RankedTensorType>(reshapeOutput.getType()); 379 380 // Instead of inserting a TransposeOp here, we check if we can fold it into 381 // the ReshapeOp. There is more complex cases where this is possible, and 382 // this check can be extended. 383 384 // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1 385 auto shape = reshapeOutputType.getShape(); 386 size_t ones = llvm::count(shape, 1); 387 // N == 1 and N != 1 388 if (ones != shape.size() - 1 && 389 !(ones == shape.size() && reshapeInputShape[0] == 1)) 390 return std::nullopt; 391 392 // Do not insert a TransposeOp, instead we fold the reshape and its attribute. 
393 auto foldedReshape = rewriter.create<ReshapeOp>( 394 reshapeOp.getLoc(), 395 RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms), 396 reshapeOutputType.getElementType()), 397 reshapeOp.getInput1(), 398 rewriter.getDenseI64ArrayAttr( 399 applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms))); 400 return foldedReshape->getResult(0); 401 } 402 403 std::optional<Value> TosaReduceTransposes::buildMappedToValue( 404 ConstOp constOp, const DenseMap<Value, Value> &valuesMap, 405 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 406 auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue()); 407 if (!denseAttr) 408 return std::nullopt; 409 auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms); 410 if (!maybeNewDenseAttr.has_value()) 411 return std::nullopt; 412 auto newDenseAttr = maybeNewDenseAttr.value(); 413 auto newConstOp = rewriter.create<ConstOp>( 414 constOp.getLoc(), newDenseAttr.getType(), newDenseAttr); 415 return newConstOp->getResult(0); 416 } 417 418 bool TosaReduceTransposes::convertDependentOps( 419 SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap, 420 IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) { 421 422 for (Operation *op : dependentOps) { 423 if (!op || op->getNumResults() != 1) 424 return false; 425 426 Value priorValue = op->getResult(0); 427 428 // It's possible on a prior transposeOp we had the same dependency and 429 // already resolved it. 430 if (valuesMap.contains(priorValue)) 431 continue; 432 433 // Keep converted ops close to the original. 
434 rewriter.setInsertionPointAfter(op); 435 436 std::optional<Value> maybeValue = 437 llvm::TypeSwitch<Operation *, std::optional<Value>>(op) 438 .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto transposeOp) { 439 return buildMappedToValue(transposeOp, valuesMap, rewriter, 440 hoistedPerms); 441 }) 442 .Default([&](Operation *op) { 443 return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms); 444 }); 445 446 if (!maybeValue.has_value()) 447 return false; 448 449 valuesMap[priorValue] = maybeValue.value(); 450 } 451 452 return true; 453 } 454 455 bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies( 456 Operation *user, std::set<TransposeOp> &validTransposes, 457 std::vector<std::pair<TransposeOp, SetVector<Operation *>>> 458 &transposeInfo) { 459 return llvm::none_of( 460 transposeInfo, 461 [&validTransposes, 462 user](const std::pair<TransposeOp, SetVector<Operation *>> &info) { 463 const auto &[transposeOp, dependentOps] = info; 464 return validTransposes.count(transposeOp) && 465 dependentOps.contains(user); 466 }); 467 } 468 469 // Dependencies are valid for an operation if none of them occur outside 470 // of the proper fan-in cones of the hoisted TransposeOp with the same perms 471 // that we can replace. Described in more detail within. 472 bool TosaReduceTransposes::dependenciesAreValid( 473 ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps, 474 std::set<TransposeOp> &validTransposes, 475 std::vector<std::pair<TransposeOp, SetVector<Operation *>>> 476 &transposeInfo) { 477 for (Operation *op : dependentOps) { 478 479 // It's OK wherever ConstOp has uses -- in the worst case, we duplicate. 480 // This can be changed later if we find the memory impact is too high. 
481 if (llvm::isa<ConstOp>(op)) 482 continue; 483 484 for (OpOperand &use : op->getUses()) { 485 // Want the uses to be (1) contained in the dependentOps of other 486 // validTransposes, or (2) to be directly used in a TransposeOp with the 487 // same perms. For (2) it means the fan-in is a subset of our 488 // dependentOps, so it is also a validTranspose that will eventually be 489 // replaced. 490 Operation *user = use.getOwner(); 491 if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) { 492 SmallVector<int32_t> otherPerms; 493 494 // Can later think about cases where transpose -> transpose 495 // or reshape -> transpose, where the transposes are not necessarily 496 // the same perms as the hoisted, if implementing a more general 497 // transform. These could be permitted. 498 if (failed(otherTranspose.getConstantPerms(otherPerms)) || 499 !llvm::equal(perms, otherPerms)) 500 return false; 501 } else if (userNotContainedInValidTransposeDependencies( 502 user, validTransposes, transposeInfo)) { 503 return false; 504 } 505 } 506 } 507 508 return true; 509 } 510 511 // Getting the set of TransposeOp that we can replace without causing 512 // the old fan-in cones of any TransposeOp to remain "live", i.e, -- not being 513 // dead code. This is done by iterating the set until convergence, since 514 // if you are used outside your own fan-in cone, it's possible to be used 515 // in another fan-in cone of a TransposeOp that is being replaced -- unless 516 // we find that that one has a usage outside of it too. 517 std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements( 518 ArrayRef<int32_t> perms, 519 std::vector<std::pair<TransposeOp, SetVector<Operation *>>> 520 &transposeInfo) { 521 // Initially, we assume they are all good to replace, 522 // and we whittle them down based on our criteria. 
523 std::set<TransposeOp> ableToReplace; 524 for (const auto &[transposeOp, _] : transposeInfo) 525 ableToReplace.insert(transposeOp); 526 527 bool gotRid; 528 do { 529 gotRid = false; 530 for (const auto &[transposeOp, dependentOps] : transposeInfo) { 531 // We don't care about it. Already invalidated. 532 if (!ableToReplace.count(transposeOp)) 533 continue; 534 535 // Check for validity. 536 if (!dependenciesAreValid(perms, dependentOps, ableToReplace, 537 transposeInfo)) { 538 ableToReplace.erase(transposeOp); 539 gotRid = true; 540 break; 541 } 542 } 543 544 } while (gotRid); 545 546 return ableToReplace; 547 } 548 549 void TosaReduceTransposes::runOnOperation() { 550 // We want to operate only within a single block. 551 if (!getOperation().getRegion().hasOneBlock()) 552 return; 553 554 IRRewriter rewriter(&getContext()); 555 // For each perms, maintain a mapping for converted ops, avoid duplication. 556 DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues; 557 // For each perms, we keep track of which TransposeOp are eligible 558 // for replacement alongside their dependentOps. 559 DenseMap<ArrayRef<int32_t>, 560 std::vector<std::pair<TransposeOp, SetVector<Operation *>>>> 561 permsToTransposeInfo; 562 563 // Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef. 564 // Use SmallVector for perms (common-case is <= 4) but std::vector otherwise 565 // since no guarantee of smallness. 566 std::vector<SmallVector<int32_t>> collectedPerms; 567 568 // This keeps track of the order across all eligible-for-replacement 569 // TransposeOp and their perms, a necessity for the final replacements. 570 std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder; 571 572 // We want to reserve the space up front, since SmallVector stores some data 573 // internally and the ArrayRef can reference that, which we don't want to get 574 // invalidated. 
575 size_t expectedMaxPerms = 0; 576 getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; }); 577 collectedPerms.reserve(expectedMaxPerms); 578 579 getOperation().walk([&](TransposeOp transposeOp) { 580 SetVector<Operation *> dependentOps; 581 collectedPerms.emplace_back(); 582 SmallVector<int32_t> &perms = collectedPerms.back(); 583 584 // Dynamic shapes are OK, but the incompatible ones will be rejected later. 585 auto input = transposeOp.getInput1(); 586 auto output = transposeOp.getOutput(); 587 588 // However, we don't support unranked tensors. 589 if (!llvm::isa<RankedTensorType>(input.getType()) || 590 !llvm::isa<RankedTensorType>(output.getType())) 591 return; 592 593 // No transformation when transpose permutation non-constant. 594 if (failed(transposeOp.getConstantPerms(perms))) 595 return; 596 597 // We let --canonicalize deal with identity transpose. 598 if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms)) 599 return; 600 601 // Can fail if some set of basic invariants is not met that we want to 602 // perform our conversions. 603 if (!collectFanIn(input.getDefiningOp(), dependentOps)) 604 return; 605 606 // Want to associate valuesMap for already converted of the same perms, 607 // since it's possible multiple hoisted transposes w/ different perms 608 // converge on an op, which would result in different transformations. 609 DenseMap<Value, Value> &valuesMap = permsToValues[perms]; 610 611 // Attempt to perform the conversions and placements into IR 612 // without turning inserted code "live". Also fills out valuesMap. 613 // Fails if there is an intermediary we do not support. 614 if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms)) 615 // Some additional operations may have been inserted, but will be 616 // removed by dead code elimination. 617 return; 618 619 // This should not happen. If it does -- it's unexpected, 620 // so we fail the pass. 
621 if (!valuesMap.contains(input)) 622 return signalPassFailure(); 623 624 // It's possible the types are not compatible (because of dynamic shapes), 625 // and in these cases, want to resolve dynamic shapes before running the 626 // pass. 627 if (output.getType() != valuesMap.at(input).getType()) 628 return; 629 630 auto &transposeInfo = permsToTransposeInfo[perms]; 631 632 // In general, we might also want to introduce "newDependentOps" 633 // if there are new usages that don't fall inside the original fan-ins 634 // (like the TransposeOp we insert for ReshapeOp), 635 // but in this case, that is specialized enough and overlaps 636 // with another direct-use TransposeOp case we need to cover anyway. 637 transposeInfo.push_back({transposeOp, dependentOps}); 638 639 // This is for the final replacement across all transposes. 640 totalTransposeOrder.push({transposeOp, perms}); 641 }); 642 643 // We want to do a full fan-in analysis on a perms-level, 644 // since if we do it on a multi-perms level, and they share (due to a shared 645 // dependency on a Reshape) then we would also get duplicate ops. 646 // Const is special cased. 647 std::set<TransposeOp> ableToReplace; 648 for (auto &[perms, transposeInfo] : permsToTransposeInfo) { 649 // Gives us back replacements that would never result in any duplicate 650 // operations being inserted by us in the IR (i.e, our goal is only to 651 // remove transposes, and not create a "new chain" to do so, but replace 652 // the existing chains). 653 // Ideally, --canonicalize is run before this pass, since it helps this 654 // analysis by removing dead code to allow more potentially acceptable 655 // transformations. 
656 auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo); 657 ableToReplace.insert(goodReplacementsForPerms.begin(), 658 goodReplacementsForPerms.end()); 659 } 660 661 // We want to do replacement across all transposes 662 // in reverse order, due to invalidation of valuesMap mappings 663 // if we did it otherwise. 664 while (!totalTransposeOrder.empty()) { 665 auto [transposeOp, perms] = totalTransposeOrder.top(); 666 totalTransposeOrder.pop(); 667 668 if (ableToReplace.count(transposeOp) == 0) 669 continue; 670 671 auto &valuesMap = permsToValues[perms]; 672 auto input = transposeOp.getInput1(); 673 674 // The purpose of this reverse iteration 675 // is to avoid valuesMap invalidation. If it happens, 676 // something is wrong. 677 if (!valuesMap.contains(input)) 678 return signalPassFailure(); 679 680 rewriter.replaceOp(transposeOp, valuesMap.at(input)); 681 } 682 683 // We can remove all dead code by going in reverse. 684 // This is because we would remove usages before we 685 // see the users. 686 getOperation().walk<WalkOrder::PostOrder, ReverseIterator>( 687 [&](Operation *op) { 688 if (isOpTriviallyDead(op)) 689 rewriter.eraseOp(op); 690 }); 691 } 692 693 } // namespace 694