//===- HoistPadding.cpp - Hoisting for tensor::PadOp ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting padding operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "hoist-padding"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::linalg::detail;

#ifndef NDEBUG
static bool debugPrintLoopInShortForm(Operation *op) {
  AsmState state(op->getParentOfType<func::FuncOp>());
  (void)state;
  if (auto forOp = dyn_cast<scf::ForOp>(op)) {
    forOp.getInductionVar().printAsOperand(dbgs(), state);
    dbgs() << " @ " << forOp.getOperation();
    return true;
  }
  return false;
}
#endif

static void debugPrintBackwardSlice(SetVector<Operation *> &backwardSlice) {
  LLVM_DEBUG(llvm::interleaveComma(backwardSlice, DBGS() << "--backwardSlice:",
                                   [](Operation *op) {
                                     dbgs() << "\n";
                                     DBGS() << "----";
                                     if (debugPrintLoopInShortForm(op)) {
                                       dbgs() << "\n";
                                       return;
                                     }
                                     dbgs() << *op << "\n";
                                   });
             DBGS() << "\n";);
}

/// Return at most nLevels of immediately enclosing scf::ForOp loops.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
/// Control-flow and other containing ops with regions are not modeled atm.
static void
getAtMostNEnclosingLoops(tensor::PadOp padOp, int nLevels,
                         SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (nLevels-- > 0 &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

/// Return the immediately enclosing scf::ForOp loops, outwards from `padOp`,
/// up to and including `untilLoop`.
/// Stops at the first parent that is not an scf::ForOp.
/// Multi-loops such as scf.parallel or linalg.tiled_loop are not modeled atm.
/// Control-flow and other containing ops with regions are not modeled atm.
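/// For illustration, with a hypothetical nest (names are illustrative only):
/// ```
/// scf.for %i            // collected last; pass this loop as `untilLoop`
///   scf.for %j          // collected second
///     scf.for %k        // collected first
///       %padded_slice = tensor.pad ...
/// ```
/// The loops are collected innermost-to-outermost, which is why callers read
/// `reverseEnclosingLoops.back()` as the outermost loop.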
static void
getEnclosingLoopsUntil(tensor::PadOp padOp, scf::ForOp untilLoop,
                       SmallVector<scf::ForOp> &reverseEnclosingLoops) {
  scf::ForOp outermostEnclosingForOp = nullptr;
  Operation *nextEnclosingOp = padOp->getParentOp();
  while (outermostEnclosingForOp != untilLoop &&
         (outermostEnclosingForOp = dyn_cast<scf::ForOp>(nextEnclosingOp))) {
    LLVM_DEBUG(DBGS() << "loops: ";
               debugPrintLoopInShortForm(outermostEnclosingForOp);
               dbgs() << "\n");
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingOp = outermostEnclosingForOp->getParentOp();
  }
}

// Get all the ops in the backwards slice starting from `padOp` and that
// are dominated by the outermost enclosing loop.
// This also requires tracking ops defining values used in the region but
// defined above.
static void computeBackwardSlice(tensor::PadOp padOp,
                                 scf::ForOp outermostEnclosingForOp,
                                 SetVector<Operation *> &backwardSlice) {
  DominanceInfo domInfo(outermostEnclosingForOp);
  BackwardSliceOptions sliceOptions;
  sliceOptions.filter = [&](Operation *op) {
    return domInfo.dominates(outermostEnclosingForOp, op) &&
           !padOp->isProperAncestor(op);
  };
  sliceOptions.inclusive = true;

  // First, add the ops required to compute the region to the backwardSlice.
  SetVector<Value> valuesDefinedAbove;
  getUsedValuesDefinedAbove(padOp.getRegion(), padOp.getRegion(),
                            valuesDefinedAbove);
  for (Value v : valuesDefinedAbove) {
    getBackwardSlice(v, &backwardSlice, sliceOptions);
  }
  // Then, add the backward slice from padOp itself.
  getBackwardSlice(padOp.getOperation(), &backwardSlice, sliceOptions);
}

//===----------------------------------------------------------------------===//
// HoistPaddingAnalysis Implementation.
//===----------------------------------------------------------------------===//

namespace {
/// Analysis class to support tensor::PadOp hoisting across multiple enclosing
/// loops. The failure conditions are:
///   1. Pad op has a use that is not an input of a LinalgOp.
///   2. Pad op does not have a constant padding value.
///   3. There is no immediately enclosing scf::ForOp.
///   4. The backward slice from the pad op to the scf::ForOp to hoist above
///      contains an unknown op with non-index type operands, a region, or a
///      memory effect.
///   5. The backward slice from the pad op to the scf::ForOp to hoist above is
///      empty.
///   6. The source tensor of pad op is not defined by an extract slice op.
///   7. The source tensor of the extract slice op is not defined outside of
///      the outermost enclosing scf::ForOp.
///   8. There is no enclosing scf::ForOp that indexes the padded data.
/// Other cases succeed and will trigger hoisting of the pad op.
struct HoistPaddingAnalysis {
  HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops);
  HoistPaddingAnalysis(tensor::PadOp padOp, scf::ForOp outermostEnclosingForOp);

  bool isValid() { return valid.has_value() && valid.value(); }
  bool isInvalid() { return valid.has_value() && !valid.value(); }

  /// Footprint of the hoistedPackedTensor, computed from the packingLoops.
  SmallVector<Value> getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                 Location loc) const;

  /// Performs optional hoisting to enable hoist padding to occur. This may be
  /// necessary when `sliceOp` is not defined outside of the outermost
  /// enclosing loop we want to hoist above.
  ///
  /// Example:
  /// ```
  /// %source = linalg.fill(%cst, %arg0)
  /// // %source is available for packing here!
  /// scf.for %i
  ///   scf.for %j
  ///     scf.for %k
  ///       %slice = tensor.extract_slice %source [%i, %j]
  ///       %padded_slice = tensor.pad %slice
  /// ```
  void enableHoistPadding(RewriterBase &rewriter);

  /// Common analysis builder to finalize the construction of the analysis once
  /// optional `enableHoistPadding` has run.
  /// `reverseEnclosingLoops.back()` is the loop to hoist above.
  void finalizeHoistPaddingAnalysis();

private:
  /// Encodes whether the analysis is valid and hoisting can proceed.
  std::optional<bool> valid;

  /// The padOp to hoist.
  tensor::PadOp opToHoist;

  /// Immediately enclosing loops considered for hoisting padding.
  SmallVector<scf::ForOp> reverseEnclosingLoops;

  /// Drop any non-index dependencies of `padOp` and `sliceOp` from
  /// `backwardSlice`. The method follows the use-def chains of the index
  /// operands consumed by `padOp` and `sliceOp` and drops the operations
  /// not part of this index computation. Afterwards, the filtered
  /// `backwardSlice` contains only the loops whose induction variable is
  /// used, directly or indirectly, to index the padded tensor. The method
  /// returns failure if the filtered backward slice contains an unexpected
  /// operation.
  ///
  /// Example:
  /// ```
  /// %source = linalg.fill(%cst, %arg0)
  /// scf.for %i
  ///   %unrelated = linalg.fill(%cst, %arg1)    // not used to index %source!
  ///   scf.for %j (%arg2 = %unrelated)
  ///     scf.for %k                             // not used to index %source!
  ///       %ubi = affine.min #map(%i)
  ///       %ubj = affine.min #map(%j)
  ///       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  ///       %padded_slice = tensor.pad %slice
  /// ```
  /// dropNonIndexDependencies(%padded_slice, %slice)
  /// removes [scf.for %k, linalg.fill(%cst, %arg1)] from backwardSlice.
  LogicalResult dropNonIndexDependencies();

public:
  /// The outermost loop, determined by `nLevels`, above which `padOp` will
  /// be hoisted.
  scf::ForOp outermostEnclosingForOp;

  /// Backward slice rooted at `padOp` and nested under
  /// `outermostEnclosingForOp`.
  SetVector<Operation *> backwardSlice;

  /// The scf::ForOps immediately enclosing `padOp` such that:
  ///   1. they are nested under `outermostEnclosingForOp` (inclusive), and
  ///   2. their induction variables are used, directly or indirectly, in the
  ///      computation of `padOp`.
  /// The span of these loops determines the footprint of the packed tensor.
  SmallVector<scf::ForOp> packingLoops;

  /// The ExtractSliceOp that feeds the PadOp we want to hoist.
  tensor::ExtractSliceOp sliceOp;

  /// If non-null, this is the unique scf::ForOp that consumes the `sliceOp`.
  scf::ForOp padConsumingForOp;
};

} // namespace

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp, int numLoops)
    : valid(std::nullopt), opToHoist(padOp) {
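  // Note: this constructor only collects the enclosing loop nest and the
  // defining extract_slice (marking the analysis invalid if either is
  // missing); the remaining legality checks run in
  // finalizeHoistPaddingAnalysis().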
  // Get at most `numLoops` of immediately enclosing loops.
  getAtMostNEnclosingLoops(opToHoist, numLoops, reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  outermostEnclosingForOp = reverseEnclosingLoops.back();
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}

HoistPaddingAnalysis::HoistPaddingAnalysis(tensor::PadOp padOp,
                                           scf::ForOp outermostEnclosingForOp)
    : valid(std::nullopt), opToHoist(padOp) {
  // Get enclosing loops until outermostEnclosingForOp.
  getEnclosingLoopsUntil(opToHoist, outermostEnclosingForOp,
                         reverseEnclosingLoops);
  if (reverseEnclosingLoops.empty()) {
    LLVM_DEBUG(DBGS() << "--No immediately enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  this->outermostEnclosingForOp = reverseEnclosingLoops.back();
  if (this->outermostEnclosingForOp != outermostEnclosingForOp) {
    LLVM_DEBUG(DBGS() << "--Unexpected outermost enclosing loop -> Skip\n");
    valid = false;
    return;
  }
  sliceOp = opToHoist.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(DBGS() << "--Cannot find the extract slice op -> Skip\n");
    valid = false;
    return;
  }
}

void HoistPaddingAnalysis::enableHoistPadding(RewriterBase &rewriter) {
  if (isInvalid())
    return;
  // If the padded data is not yet available before entering the outermost
  // enclosing loop, try to apply hoisting on this outermost loop.
  // TODO: we may want finer-grained hoisting of only that particular
  // `sliceOp`.
  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    outermostEnclosingForOp = cast<scf::ForOp>(
        hoistLoopInvariantSubsets(rewriter, outermostEnclosingForOp));
  }
}

void HoistPaddingAnalysis::finalizeHoistPaddingAnalysis() {
  if (isInvalid())
    return;

  if (!outermostEnclosingForOp.isDefinedOutsideOfLoop(sliceOp.getSource())) {
    LLVM_DEBUG(DBGS() << "--outermostEnclosingForOp:\n"
                      << outermostEnclosingForOp << "\n"
                      << "--sliceOp: " << sliceOp << "\n"
                      << "--sliceOp.getSource(): " << sliceOp.getSource()
                      << "\n");
    LLVM_DEBUG(DBGS() << "----Source not defined outside of loops -> Skip\n");
    valid = false;
    return;
  }
  if (sliceOp->hasOneUse()) {
    padConsumingForOp = dyn_cast<scf::ForOp>(*(sliceOp->getUsers().begin()));
  }

  // Check the region of `padOp` depends on a constant only. Adding hoisting
  // support for arbitrary padding regions would require cloning all
  // dependencies captured by the padding region.
  Value paddingValue = opToHoist.getConstantPaddingValue();
  if (!paddingValue ||
      !isa_and_nonnull<arith::ConstantOp>(paddingValue.getDefiningOp())) {
    LLVM_DEBUG(DBGS() << "Cannot find constant padding value -> Skip\n");
    valid = false;
    return;
  }

  computeBackwardSlice(opToHoist, outermostEnclosingForOp, backwardSlice);
  if (backwardSlice.size() <= 1) {
    valid = false;
    return;
  }

  debugPrintBackwardSlice(backwardSlice);
  // Remove all ops in the backward slice that are not used to index
  // the padded tensor. In particular, keep `padOp`, `sliceOp`, and
  // the loop and affine operations used for the index computation.
  if (failed(dropNonIndexDependencies())) {
    LLVM_DEBUG(DBGS() << "--Cannot dropNonIndexDependencies -> Skip\n");
    valid = false;
    return;
  }
  debugPrintBackwardSlice(backwardSlice);

  // Add only the loops part of the filtered `backwardSlice` to the
  // packing loops. All other loops are not used to index the padded
  // data and consequently access the same data in every loop
  // iteration. Adding them to the packing loops would increase the
  // cache footprint of the packed data by storing the same data
  // multiple times.
  for (scf::ForOp forOp : llvm::reverse(reverseEnclosingLoops))
    if (backwardSlice.contains(forOp))
      packingLoops.push_back(forOp);

  // TODO: for multiple loops we need to track the use to the innermost loop.
  if (packingLoops.size() > 1 && padConsumingForOp) {
    LLVM_DEBUG(DBGS() << "--Cannot hoist multiple loops through iter_args -> "
                         "Downgrade to 1 loop\n");
    packingLoops.resize(1);
  }

  // Note: at this point, packing loops may be empty but we would still like
  // to hoist the padding if so specified.

  // The analysis is valid and hoisting can occur.
  valid = true;
}

LogicalResult HoistPaddingAnalysis::dropNonIndexDependencies() {
  // Set of all values used for index computation.
  SetVector<Value> indexEdges;

  // Add all index operands of `operation` to `indexEdges`. An index operand
  // is an operand of type index.
  auto addIndexOperandsToIndexEdges = [&](Operation *operation) {
    for (Value operand : operation->getOperands())
      if (operand.getType().isIndex())
        indexEdges.insert(operand);
  };

  // Check if any operation result is contained in `indexEdges`.
  auto hasIndexResult = [&](Operation *operation) {
    return llvm::any_of(operation->getResults(), [&](Value result) {
      return indexEdges.contains(result);
    });
  };

  // Starting from `opToHoist` and `sliceOp` walk the use-def edges of index
  // type in `backwardSlice`. Add the index operands of an operation to
  // `indexEdges` and remove all operations from `backwardSlice` that are not
  // part of the index computation.
  //
  // Example:
  // ```
  // %source = linalg.fill(%cst, %arg0)
  // scf.for %i
  //   %unrelated = linalg.fill(%cst, %arg1)    // not used to index %source!
  //   scf.for %j (%arg2 = %unrelated)
  //     scf.for %k                             // not used to index %source!
  //       %ubi = affine.min #map(%i)
  //       %ubj = affine.min #map(%j)
  //       %slice = tensor.extract_slice %source [%i, %j] [%ubi, %ubj]
  //       %padded_slice = tensor.pad %slice
  // ```
  // After iterating `backwardSlice` we obtain:
  // indexEdges = [%i, %j, %ubi, %ubj]
  // backwardSlice = backwardSlice / [linalg.fill(%cst, %arg1), scf.for %k]
  SetVector<Operation *> operationsToRemove;
  for (Operation *op : llvm::reverse(backwardSlice)) {
    // Add the index operands of `opToHoist` and `sliceOp` to start the
    // exploration of the index computation.
    if (op == opToHoist || op == sliceOp) {
      addIndexOperandsToIndexEdges(op);
      continue;
    }
    // Add the index operands of the loop if its induction variable is
    // used for index computation.
    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      if (!hasIndexResult(op) && indexEdges.contains(forOp.getInductionVar())) {
        addIndexOperandsToIndexEdges(op);
        continue;
      }
    }
    // Add the index operands of all other operations if at least one result
    // is used for index computation.
    if (hasIndexResult(op)) {
      addIndexOperandsToIndexEdges(op);
      // Check the operands of the remaining operations all have index type.
      if (llvm::any_of(op->getOperandTypes(),
                       [](Type type) { return !type.isIndex(); })) {
        LLVM_DEBUG(DBGS() << "Unsupported op with non index type operands: "
                          << op << " -> Skip\n");
        return failure();
      }
      // Check the remaining operations do not have regions or memory effects.
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
      bool hasMemoryEffect = effectInterface && !effectInterface.hasNoEffect();
      if (hasMemoryEffect || op->getNumRegions() != 0) {
        LLVM_DEBUG(DBGS() << "Unsupported op with region or memory effect: "
                          << op << " -> Skip\n");
        return failure();
      }
      continue;
    }
    // Remove all other operations not used by the index computation. An
    // exception is constant operations that may be used by `opToHoist`.
    if (!isa<arith::ConstantOp>(op))
      operationsToRemove.insert(op);
  }
  backwardSlice.set_subtract(operationsToRemove);
  return success();
}

SmallVector<Value>
HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
                                                  Location loc) const {
  SmallVector<Value> dynamicTensorSizes;

  // Upper bound the packing loop lengths to size the packed tensor. Taking
  // upper bounds can make the sizes of the packed tensor independent of the
  // enclosing loops. This independence is a prerequisite for reusing the same
  // buffer for all enclosing loop iterations and hoisting its allocation out
  // of the enclosing loops.
  for (auto forOp : packingLoops) {
    // Compute an upper bound `ubVal` for the upper bound of `forOp`.
    FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
        rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
        /*stopCondition=*/
        [&](Value v, std::optional<int64_t> d,
            ValueBoundsConstraintSet &cstr) {
          if (v == forOp.getUpperBound())
            return false;
          // Compute a bound that is independent of any affine op results.
          Operation *op = v.getDefiningOp();
          if (!op)
            return true;
          return !isa<affine::AffineMinOp, affine::AffineMaxOp,
                      affine::AffineApplyOp>(op);
        },
        /*closedUB=*/true);
    assert(succeeded(loopUb) && "could not get upper bound");
    Value ubVal = getValueOrCreateConstantIndexOp(rewriter, loc, *loopUb);

    // Compute the maximal packing loop length as (ub - lb).ceilDiv(step) and
    // store the result to `dynamicTensorSizes`.
    // TODO: instead of using the lower bound of `forOp` directly, implement a
    // lower bound computation similar to the upper bound computation.
    AffineExpr lb, ub, step;
    bindDims(rewriter.getContext(), lb, ub);
    bindSymbols(rewriter.getContext(), step);
    Value res = rewriter.createOrFold<affine::AffineApplyOp>(
        loc, (ub - lb).ceilDiv(step),
        ValueRange{forOp.getLowerBound(), ubVal,
                   cast<scf::ForOp>(forOp).getStep()});
    dynamicTensorSizes.push_back(res);
  }

  return dynamicTensorSizes;
}

static bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) {
  return outer.isDefinedOutsideOfLoop(v) || matchPattern(v, m_Constant());
}

//===----------------------------------------------------------------------===//
// buildPackingLoopNest Implementation.
//===----------------------------------------------------------------------===//

/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop comprised in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
static Value buildLoopIterationCount(RewriterBase &rewriter, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  if (!isDefinedOutsideOrConstant(outer, forOp.getLowerBound()) ||
      !isDefinedOutsideOrConstant(outer, forOp.getStep()))
    return Value();
  Value ivVal = forOp.getInductionVar(), lbVal = forOp.getLowerBound(),
        stepVal = forOp.getStep();
  auto loc = forOp->getLoc();
  return rewriter.createOrFold<affine::AffineApplyOp>(
      loc, (iv - lb).ceilDiv(step), ValueRange{ivVal, lbVal, stepVal});
}

// Build a packing loop nest by iteratively traversing the backward slice and
// cloning the operations, iteratively stepping into the loops that we
// encounter. The implementation proceeds in a stack-like fashion:
//   1. Iteratively clone and step into the loops, pushing the
//      `hoistedPackedTensor` deeper in the stack.
//   2. At the innermost loop level, create a TransposeOp if `transposeVector`
//      is non-empty.
//   3. At the innermost loop level, create an InsertSliceOp.
//   4. Iteratively pop and yield the result of the InsertSliceOp across the
//      cloned loops.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, RankedTensorType transposedTensorType,
    tensor::EmptyOp emptyOp, const HoistPaddingAnalysis &analysis) {
  SmallVector<OpFoldResult> offsets, sizes, strides;
  SmallVector<Value> clonedLoopIvs, leadingHoistedPackedTensorIndexings;

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  // Step 0. Populate bvm with opToHoist.getSource if relevant.
  BlockArgument bbArg = dyn_cast<BlockArgument>(opToHoist.getSource());
  while (bbArg) {
    auto forOp = dyn_cast<scf::ForOp>(bbArg.getOwner()->getParentOp());
    if (!forOp)
      break;
    if (forOp != outerLoop && !outerLoop->isAncestor(forOp))
      break;
    OpOperand &operand = *forOp.getTiedLoopInit(bbArg);
    bvm.map(bbArg, operand.get());
    bbArg = dyn_cast<BlockArgument>(operand.get());
  }

  // Step 1. Iteratively clone loops and push `hoistedPackedTensor`.
  Value hoistedPackedTensor = emptyOp.getResult();
  OpBuilder::InsertionGuard g(rewriter);
  for (Operation *op : analysis.backwardSlice) {
    // Specifically sit out in the extract_slice(hoistedPackedTensor) case:
    // this is the piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
      if (bvm.lookupOrDefault(sliceOp.getSource()) == hoistedPackedTensor) {
        LLVM_DEBUG(DBGS() << "--Skip: " << sliceOp << "\n");
        continue;
      }
    }

    // Clone all operations except loops which require special handling.
    auto forOp = dyn_cast<scf::ForOp>(op);
    if (!forOp) {
      // We are at the right insertion point within the loop nest.
      rewriter.clone(*op, bvm);
      continue;
    }

    // Create a packing loop that takes `hoistedPackedTensor` as iteration
    // argument.
    auto clonedForOp = rewriter.create<scf::ForOp>(
        loc, bvm.lookupOrDefault(forOp.getLowerBound()),
        bvm.lookupOrDefault(forOp.getUpperBound()),
        bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);

    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Do not insert guard here, we get deeper into the loop nest.
    rewriter.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount =
        buildLoopIterationCount(rewriter, outerLoop, clonedForOp);

    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingHoistedPackedTensorIndexings.push_back(
        loopIndependentIterationCount);
    hoistedPackedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Step 2. Construct offsets, sizes and strides for the innermost level of
  // the packing loop.
  int64_t nPackedLoops = clonedLoopIvs.size();
  // offsets = [clonedLoopIvs, 0 .. 0].
  offsets =
      SmallVector<OpFoldResult>{leadingHoistedPackedTensorIndexings.begin(),
                                leadingHoistedPackedTensorIndexings.end()};
  offsets.append(paddedRank, rewriter.getIndexAttr(0));
  // sizes = [1 .. 1, transposedShape].
  sizes = SmallVector<OpFoldResult>(nPackedLoops, rewriter.getIndexAttr(1));
  for (int64_t sz : transposedTensorType.getShape()) {
    // TODO: go grab dims when needed, atm tensor::PadOp yields a static
    // tensor.
    if (ShapedType::isDynamic(sz))
      return failure();
    sizes.push_back(rewriter.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  strides = SmallVector<OpFoldResult>(nPackedLoops + paddedRank,
                                      rewriter.getIndexAttr(1));

  // Step 3. Optionally transpose the padded tensor.
  TransposeOp maybeTransposeOp;
  Value paddedTensor = bvm.lookup(opToHoist.getResult());
  if (!transposeVector.empty()) {
    Value outputTensor = rewriter.create<tensor::ExtractSliceOp>(
        loc, transposedTensorType, hoistedPackedTensor, offsets, sizes,
        strides);
    maybeTransposeOp = rewriter.create<linalg::TransposeOp>(
        loc, paddedTensor, outputTensor, transposeVector);
    paddedTensor = maybeTransposeOp.getResult()[0];
  }

  // Innermost tensor.insert_slice and yields are optional / need loops.
  if (nPackedLoops > 0) {
    // Step 4. Create InsertSliceOp at the innermost loop level, inserting an
    // optionally transposed padded slice into the packed tensor.
    Value inserted = rewriter.create<tensor::InsertSliceOp>(
        loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides);

    // Step 5. Iteratively pop the stack and propagate the yield.
    Value valueToYield = inserted;
    for (Value iv : llvm::reverse(clonedLoopIvs)) {
      auto forOp = scf::getForInductionVarOwner(iv);
      rewriter.setInsertionPointToEnd(&forOp.getRegion().front());
      rewriter.create<scf::YieldOp>(loc, valueToYield);
      valueToYield = forOp.getResult(0);
    }
  }

  return PackingResult{
      offsets,
      sizes,
      strides,
      clonedLoopIvs,
      leadingHoistedPackedTensorIndexings,
      maybeTransposeOp,
      cast<tensor::PadOp>(bvm.lookup(opToHoist.getResult()).getDefiningOp())};
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
static FailureOr<PackingResult> buildPackingLoopNestImpl(
    RewriterBase &rewriter, IRMapping &bvm, tensor::PadOp opToHoist,
    ArrayRef<int64_t> transposeVector, const HoistPaddingAnalysis &analysis) {
  // Update actual number of loops, which may be smaller.
  int nPackedLoops = analysis.packingLoops.size();
  LLVM_DEBUG(DBGS() << "\n";
             DBGS() << "Func:\n"
                    << *opToHoist->getParentOfType<func::FuncOp>() << "\n";
             DBGS() << "Start hoisting above " << nPackedLoops << " loops\n");

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();

  // Compute the type of the transposed padded tensor.
  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(paddedTensorType, transposeVector);
  if (failed(transposedTensorType)) {
    LLVM_DEBUG(DBGS() << "--Could not compute transposed type -> Skip\n");
    return failure();
  }

  // Create the packed tensor<?x?x..? x transposedShape>.
  SmallVector<int64_t> packedShape(nPackedLoops, ShapedType::kDynamic);
  // TODO: go grab dims when needed, atm tensor::PadOp yields a static tensor.
  llvm::append_range(packedShape, transposedTensorType->getShape());
  auto hoistedPackedTensorType = RankedTensorType::get(
      packedShape, transposedTensorType->getElementType());

  // Set the insertion point right before the outer loop and start packing.
  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(outerLoop);
  SmallVector<Value> dynamicTensorSizes =
      analysis.getHoistedPackedTensorSizes(rewriter, loc);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, hoistedPackedTensorType.getShape(),
      hoistedPackedTensorType.getElementType(), dynamicTensorSizes);

  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  *transposedTensorType, emptyOp, analysis);
}

/// Build the packing loop nest required to hoist `opToHoist` above
/// `outermostEnclosingForOp`.
/// The loop nest is built just before `outermostEnclosingForOp`.
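/// Schematically, for a single packing loop, the generated IR has the
/// following shape (a sketch only; value names are illustrative and the
/// affine expressions are shown in shorthand):
/// ```
///   %empty = tensor.empty(%size) : tensor<?x..x..xelt>
///   %packed = scf.for %i ... iter_args(%iter = %empty) {
///     %count = affine.apply (%i - lb) ceildiv step
///     %slice = tensor.extract_slice %source [%i, ...]
///     %padded = tensor.pad %slice ...
///     %inserted = tensor.insert_slice %padded into
///         %iter[%count, 0, .., 0] [1, paddedShape] [1, .., 1]
///     scf.yield %inserted
///   }
/// ```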
FailureOr<PackingResult> mlir::linalg::detail::buildPackingLoopNest(
    RewriterBase &rewriter, tensor::PadOp opToHoist,
    scf::ForOp outermostEnclosingForOp, ArrayRef<int64_t> transposeVector) {
  HoistPaddingAnalysis analysis(opToHoist, outermostEnclosingForOp);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }
  IRMapping bvm;
  return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector,
                                  analysis);
}

//===----------------------------------------------------------------------===//
// hoistPaddingOnTensors Implementation.
//===----------------------------------------------------------------------===//

/// Return true if we can walk back the use-def chain from `extractSliceOp` to
/// `expectedSource` going through DestinationStyleOpInterface inits only.
/// This is a poor man's analysis that is sufficient to check that the
/// extractSliceOp matches the tensor.pad we want to hoist.
/// In the future, it will be easier to ensure this with a matching symmetric
/// tensor.unpad op.
static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp,
                                      Value expectedSource) {
  LLVM_DEBUG(DBGS() << "Start tracesBackToExpectedValue on: " << extractSliceOp
                    << "\n");
  LLVM_DEBUG(DBGS() << "--with extractSlice: " << extractSliceOp << "\n");
  Value source = extractSliceOp.getSource();
  LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n");
  while (source && source != expectedSource) {
    auto destOp =
        dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp());
    if (!destOp)
      break;
    LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n");
    source = destOp.getDpsInitOperand(cast<OpResult>(source).getResultNumber())
                 ->get();
  }
  LLVM_DEBUG(DBGS() << "--final source: " << source << "\n");
  LLVM_DEBUG(DBGS() << "--expected source: " << expectedSource << "\n");
  return source == expectedSource;
}

/// If the original consumer of `outerSliceOp` was a `forOp` (i.e. through an
/// iter arg), propagate the `hoistedPackedTensor` value through the same iter
/// arg.
/// TODO: for multiple loops we need to track the use to the innermost loop.
///
/// Match:
/// ```
///   %outerSliceOp = tensor.extract_slice ..
///   %f = scf.for ... iter_args(%arg0 = %outerSliceOp) {
///     %hoistedPackedTensor = tensor.pad %arg0
///     %1 = compute %hoistedPackedTensor
///     %2 = tensor.extract_slice %1
///     scf.yield %2
///   }
/// ```
///
/// and rewrite as:
/// ```
///   %outerSliceOp = tensor.extract_slice ..
///   %hoistedPackedTensor = tensor.pad %outerSliceOp
///   %f = scf.for ... iter_args(%arg0 = %hoistedPackedTensor) {
///     %1 = compute %arg0
///     scf.yield %1
///   }
///   %2 = tensor.extract_slice %f
/// ```
///
/// Return null when no rewrite happened.
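/// Note: `replaceWithAdditionalYields` appends the new iter_args after the
/// loop's original ones, so the propagated value is read back as result
/// `numOriginalForOpResults + iterArgNumber` of the rewritten loop.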
static tensor::ExtractSliceOp
padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                      Value hoistedPackedTensor,
                      tensor::ExtractSliceOp outerSliceOp, scf::ForOp forOp) {
  LLVM_DEBUG(DBGS() << "Start padThroughLoopIterArg on: " << forOp << "\n");
  LLVM_DEBUG(DBGS() << "--paddedValueBeforeHoisting: "
                    << paddedValueBeforeHoisting << "\n");
  OpOperand *pUse = nullptr;
  for (OpOperand &use : outerSliceOp->getUses()) {
    if (use.getOwner() == forOp) {
      assert(!pUse && "Multiple slice uses in the for loop");
      pUse = &use;
    }
  }
  assert(pUse && "No slice use in the for loop");
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());

  unsigned iterArgNumber = forOp.getTiedLoopResult(pUse).getResultNumber();
  auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                    .getDefiningOp<tensor::ExtractSliceOp>();
  if (!yieldingExtractSliceOp)
    return tensor::ExtractSliceOp();

  // Poor man's analysis sufficient to ensure extractSlice matches tensor.pad.
  // In the future, it will be easier to ensure this with a matching symmetric
  // tensor.unpad op.
  if (!tracesBackToExpectedValue(yieldingExtractSliceOp,
                                 paddedValueBeforeHoisting))
    return tensor::ExtractSliceOp();

  SmallVector<Value> initArgs = forOp.getInitArgs();
  initArgs[iterArgNumber] = hoistedPackedTensor;
  SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();

  int64_t numOriginalForOpResults = initArgs.size();
  LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
                    << "\n");
  tensor::ExtractSliceOp extracted;
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(forOp);
    extracted = rewriter.create<tensor::ExtractSliceOp>(
        hoistedPackedTensor.getLoc(), hoistedPackedTensor,
        outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
        outerSliceOp.getMixedStrides());
    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
  }
  scf::ForOp newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
      rewriter, initArgs, /*replaceInitOperandUsesInLoop=*/true,
      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
        return yieldOperands;
      }));

  LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
                    << "\n");
  LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
  LLVM_DEBUG(DBGS() << "with result #"
                    << numOriginalForOpResults + iterArgNumber
                    << " of forOp, giving us: " << extracted << "\n");
  rewriter.startOpModification(extracted);
  extracted.getSourceMutable().assign(
      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
  rewriter.finalizeOpModification(extracted);

  LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                    << "\n");
  LLVM_DEBUG(DBGS() << "with region iter arg #"
                    << numOriginalForOpResults + iterArgNumber << "\n");
  rewriter.replaceAllUsesWith(
      paddedValueBeforeHoisting,
      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));

  return extracted;
}

/// Produce a tensor extracted from the packingResult. This can be used as a
/// replacement for `opToHoist` in callers.
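/// When packing loops exist, the replacement is an extract_slice of the packed
/// tensor at the current iteration (a sketch; names are illustrative):
/// ```
///   %replacement = tensor.extract_slice %hoistedPackedTensor
///       [%count_0, .., %count_n, 0, .., 0] [1, .., 1, paddedShape] [1, .., 1]
/// ```
/// where each %count_i is the loop-independent iteration count of one packing
/// loop. If no packing loop exists, the hoisted pad itself is returned, and if
/// the pad was consumed through a loop iter_arg, the value is instead
/// propagated via `padThroughLoopIterArg`.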
static Value replaceByPackingResult(RewriterBase &rewriter,
                                    const IRMapping &bvm,
                                    tensor::PadOp opToHoist,
                                    RankedTensorType transposedTensorType,
                                    const HoistPaddingAnalysis &analysis,
                                    const PackingResult &packingResult) {
  // The replacement occurs under a single insertion point within the original
  // loop, just before opToHoist.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(opToHoist);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  int paddedRank = paddedTensorType.getRank();

  int64_t nPackedLoops = packingResult.clonedLoopIvs.size();
  LLVM_DEBUG(DBGS() << "nPackedLoops: " << nPackedLoops << " loops\n");

  scf::ForOp outerLoop = analysis.outermostEnclosingForOp;
  ArrayRef<scf::ForOp> packingLoops = analysis.packingLoops;

  Value hoistedPackedTensor;
  SmallVector<Value> loopIterationCounts;
  SmallVector<OpFoldResult> offsets(nPackedLoops + paddedRank,
                                    rewriter.getIndexAttr(0));
  if (nPackedLoops > 0) {
    loopIterationCounts =
        llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
          return buildLoopIterationCount(rewriter, outerLoop,
                                         cast<scf::ForOp>(loop));
        }));
    // Assert all loop iteration counts can be computed.
    if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
      llvm_unreachable("loop independence prerequisite not met");

    // offsets = [maybe_leading_ivs = originalLoopIvs, 0 .. 0].
    std::copy(loopIterationCounts.begin(), loopIterationCounts.end(),
              offsets.begin());
    hoistedPackedTensor =
        scf::getForInductionVarOwner(packingResult.clonedLoopIvs.front())
            ->getResult(0);
  } else {
    // If no loops were created, this is just hoisting without packing.
    hoistedPackedTensor = bvm.lookup(opToHoist.getResult());
  }

  LLVM_DEBUG(DBGS() << "hoistedPackedTensor: " << hoistedPackedTensor << "\n");

  // If the consumer of `padOp` was a `forOp`, propagate through iter args.
  scf::ForOp forOp = analysis.padConsumingForOp;
  if (forOp) {
    return padThroughLoopIterArg(rewriter, opToHoist, hoistedPackedTensor,
                                 analysis.sliceOp, forOp);
  }

  // offsets = [maybe_leading_ivs, 0 .. 0].
  // sizes = [1 .. 1, transposedShape] (defined above).
  // strides = [1 .. 1] (defined above).
  return rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedTensorType, hoistedPackedTensor, offsets,
      packingResult.sizes, packingResult.strides);
}

FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps) {
  LLVM_DEBUG(DBGS() << "\n"; DBGS() << " Try to hoist " << *(opToHoist) << "\n";
             DBGS() << " by " << numLoops << " loops\n");

  HoistPaddingAnalysis analysis(opToHoist, numLoops);
  analysis.enableHoistPadding(rewriter);
  analysis.finalizeHoistPaddingAnalysis();
  if (!analysis.isValid()) {
    LLVM_DEBUG(DBGS() << "--Analysis failed -> Skip\n");
    return failure();
  }

  // Construct the packing loop nest.
  IRMapping bvm;
  FailureOr<PackingResult> packingResult = buildPackingLoopNestImpl(
      rewriter, bvm, opToHoist, transposeVector, analysis);
  if (failed(packingResult)) {
    LLVM_DEBUG(DBGS() << "--buildPackingLoopNestImpl failed -> Skip\n");
    return failure();
  }

  if (!transposeVector.empty())
    transposeOps.push_back(packingResult->maybeTransposeOp);

  FailureOr<RankedTensorType> transposedTensorType =
      tensor::computeTransposedType(opToHoist.getResultType(),
                                    transposeVector);
  assert(succeeded(transposedTensorType) && "unexpected failure in type");

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  Value newResult =
      replaceByPackingResult(rewriter, bvm, opToHoist, *transposedTensorType,
                             analysis, *packingResult);

  Location loc = opToHoist->getLoc();
  RankedTensorType paddedTensorType = opToHoist.getResultType();
  if (!transposeVector.empty()) {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newResult.getDefiningOp());
    // Transpose the packed tensor back to the original storage order.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, paddedTensorType.getShape(), paddedTensorType.getElementType());
    TransposeOp unTransposeOp = rewriter.create<linalg::TransposeOp>(
        loc, newResult, emptyTensor, transposeVector);
    newResult = unTransposeOp.getResult()[0];
    transposeOps.push_back(unTransposeOp);
  }

  LLVM_DEBUG(DBGS() << "newResult: " << newResult << "\n");
  LLVM_DEBUG(
      DBGS() << "After hoisting: "
             << newResult.getDefiningOp()->getParentOfType<func::FuncOp>()
             << "\n");

  // Make the newly cloned `opToHoist` available to the caller.
  hoistedOp = packingResult->hoistedPadOp;

  LLVM_DEBUG(DBGS() << "--SUCCESS\n");
  return newResult;
}

FailureOr<Value> mlir::linalg::hoistPaddingOnTensors(
    tensor::PadOp opToHoist, int64_t numLoops,
    ArrayRef<int64_t> transposeVector, tensor::PadOp &hoistedOp,
    SmallVectorImpl<TransposeOp> &transposeOps) {
  IRRewriter rewriter(opToHoist.getContext());
  return hoistPaddingOnTensors(rewriter, opToHoist, numLoops, transposeVector,
                               hoistedOp, transposeOps);
}