1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Interfaces/TilingInterface.h" 24 #include "llvm/Support/Debug.h" 25 26 #define DEBUG_TYPE "tile-using-interface" 27 28 using namespace mlir; 29 30 scf::SCFTilingOptions & 31 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 32 assert(!tileSizeComputationFunction && "tile sizes already set"); 33 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 34 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 35 OpBuilder::InsertionGuard guard(b); 36 b.setInsertionPointToStart( 37 &op->getParentOfType<func::FuncOp>().getBody().front()); 38 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 39 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 40 return v; 41 })); 42 }; 43 return *this; 44 } 45 46 /// Helper method to adjust the interchange vector to match the iteration 47 /// domain. 48 static SmallVector<unsigned> 49 fillInterchangeVector(ArrayRef<unsigned> interchangeVector, 50 size_t iterationDomainSize) { 51 SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector); 52 if (filledVector.size() < iterationDomainSize) { 53 auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize); 54 filledVector.append(range.begin(), range.end()); 55 } 56 if (filledVector.size() > iterationDomainSize) 57 filledVector.resize(iterationDomainSize); 58 return filledVector; 59 } 60 61 /// Helper method to apply permutation to a vector 62 template <typename T> 63 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector, 64 ArrayRef<unsigned> interchange) { 65 assert(interchange.size() == vector.size()); 66 return llvm::to_vector( 67 llvm::map_range(interchange, [&](unsigned val) { return vector[val]; })); 68 } 69 /// Helper method to apply to invert a permutation. 70 static SmallVector<unsigned> 71 invertPermutationVector(ArrayRef<unsigned> interchange) { 72 SmallVector<unsigned> inversion(interchange.size()); 73 for (auto pos : llvm::enumerate(interchange)) { 74 inversion[pos.value()] = pos.index(); 75 } 76 return inversion; 77 } 78 /// Method to check if an interchange vector is a permutation. 79 static bool isPermutation(ArrayRef<unsigned> interchange) { 80 llvm::SmallDenseSet<unsigned, 4> seenVals; 81 for (auto val : interchange) { 82 if (seenVals.count(val)) 83 return false; 84 seenVals.insert(val); 85 } 86 return seenVals.size() == interchange.size(); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // TileUsingSCFForOp pattern implementation. 91 //===----------------------------------------------------------------------===// 92 93 // Check if `stride` evenly divides the trip count `size - offset`. 94 static bool tileDividesIterationDomain(Range loopRange) { 95 Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); 96 if (!offsetAsInt) 97 return false; 98 Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); 99 if (!sizeAsInt) 100 return false; 101 Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); 102 if (!strideAsInt) 103 return false; 104 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); 105 } 106 107 /// Generate an empty loop nest that represents the tiled loop nest shell. 108 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 109 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 110 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 111 /// the 112 /// tile processed within the inner most loop. 113 static SmallVector<scf::ForOp> 114 generateTileLoopNest(OpBuilder &builder, Location loc, 115 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 116 SmallVector<OpFoldResult> &offsets, 117 SmallVector<OpFoldResult> &sizes) { 118 assert(!loopRanges.empty() && "expected at least one loop range"); 119 assert(loopRanges.size() == tileSizeVals.size() && 120 "expected as many tile sizes as loop ranges"); 121 OpBuilder::InsertionGuard guard(builder); 122 SmallVector<scf::ForOp> loops; 123 offsets.resize(loopRanges.size()); 124 sizes.resize(loopRanges.size()); 125 126 // The tile size to use (to avoid out of bounds access) is minimum of 127 // `tileSize` and `ub - iv`, where `iv` is the induction variable 128 // of the tiled loop. 129 AffineExpr s0, s1, d0; 130 bindDims(builder.getContext(), d0); 131 bindSymbols(builder.getContext(), s0, s1); 132 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 133 134 for (auto loopRange : llvm::enumerate(loopRanges)) { 135 Value offset = 136 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 137 Value size = 138 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 139 // No loops if tile size is zero. Set offset and size to the loop 140 // offset and size. 141 if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 142 offsets[loopRange.index()] = offset; 143 sizes[loopRange.index()] = size; 144 continue; 145 } 146 147 auto loop = builder.create<scf::ForOp>( 148 loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{}, 149 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 150 ValueRange /*iterArgs*/) { 151 bool canAvoidMap = tileDividesIterationDomain( 152 Range{loopRange.value().offset, loopRange.value().size, 153 tileSizeVals[loopRange.index()]}); 154 Value boundedTileSize = 155 (canAvoidMap) 156 ? tileSizeVals[loopRange.index()] 157 : builder.create<AffineMinOp>( 158 bodyLoc, minMap, 159 ValueRange{iv, tileSizeVals[loopRange.index()], size}); 160 sizes[loopRange.index()] = boundedTileSize; 161 builder.create<scf::YieldOp>(loc); 162 }); 163 offsets[loopRange.index()] = loop.getInductionVar(); 164 loops.push_back(loop); 165 builder.setInsertionPoint(loop.getBody()->getTerminator()); 166 } 167 return loops; 168 } 169 170 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, 171 scf::SCFTilingOptions options, 172 PatternBenefit benefit) 173 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 174 options(std::move(options)) {} 175 176 scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, 177 MLIRContext *context, 178 scf::SCFTilingOptions options, 179 PatternBenefit benefit) 180 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 181 options(std::move(options)) {} 182 183 FailureOr<scf::SCFTilingResult> 184 scf::TileUsingSCFForOp::returningMatchAndRewrite( 185 TilingInterface op, PatternRewriter &rewriter) const { 186 OpBuilder::InsertionGuard guard(rewriter); 187 rewriter.setInsertionPointAfter(op); 188 189 if (!options.tileSizeComputationFunction) { 190 return rewriter.notifyMatchFailure( 191 op, "missing tile size computation function"); 192 } 193 194 // 1. Get the range of the loops that are represented by the operation. 195 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 196 size_t numLoops = iterationDomain.size(); 197 if (numLoops == 0) { 198 return rewriter.notifyMatchFailure( 199 op, "unable to tile op with no iteration domain"); 200 } 201 202 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 203 // skips tiling a particular dimension. This convention is significantly 204 // simpler to handle instead of adjusting affine maps to account for missing 205 // dimensions. 206 SmallVector<Value> tileSizeVector = 207 options.tileSizeComputationFunction(rewriter, op); 208 if (tileSizeVector.size() < iterationDomain.size()) { 209 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 210 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 211 } 212 213 scf::SCFTilingResult tilingResult; 214 SmallVector<OpFoldResult> offsets, sizes; 215 { 216 // If there is an interchange specified, permute the iteration domain and 217 // the tile sizes. 218 SmallVector<unsigned> interchangeVector; 219 if (!options.interchangeVector.empty()) { 220 interchangeVector = fillInterchangeVector(options.interchangeVector, 221 iterationDomain.size()); 222 } 223 if (!interchangeVector.empty()) { 224 if (!isPermutation(interchangeVector)) { 225 return rewriter.notifyMatchFailure( 226 op, "invalid intechange vector, not a permutation of the entire " 227 "iteration space"); 228 } 229 230 iterationDomain = 231 applyPermutationToVector(iterationDomain, interchangeVector); 232 tileSizeVector = 233 applyPermutationToVector(tileSizeVector, interchangeVector); 234 } 235 236 // 3. Materialize an empty loop nest that iterates over the tiles. These 237 // loops for now do not return any values even if the original operation has 238 // results. 239 tilingResult.loops = generateTileLoopNest( 240 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 241 242 if (!interchangeVector.empty()) { 243 auto inversePermutation = invertPermutationVector(interchangeVector); 244 offsets = applyPermutationToVector(offsets, inversePermutation); 245 sizes = applyPermutationToVector(sizes, inversePermutation); 246 } 247 248 LLVM_DEBUG({ 249 if (!tilingResult.loops.empty()) { 250 llvm::errs() << "LoopNest shell :\n"; 251 tilingResult.loops.front().dump(); 252 llvm::errs() << "\n"; 253 } 254 }); 255 256 // 4. Generate the tiled implementation within the inner most loop. 257 if (!tilingResult.loops.empty()) 258 rewriter.setInsertionPoint( 259 tilingResult.loops.back().getBody()->getTerminator()); 260 SmallVector<Operation *> tiledImplementation = 261 op.getTiledImplementation(rewriter, offsets, sizes); 262 if (tiledImplementation.size() != 1) { 263 return rewriter.notifyMatchFailure( 264 op, "expected tiled implementation to return a single op"); 265 } 266 tilingResult.tiledOp = tiledImplementation[0]; 267 268 LLVM_DEBUG({ 269 if (!tilingResult.loops.empty()) { 270 llvm::errs() << "After tiled implementation :\n"; 271 tilingResult.loops.front().dump(); 272 llvm::errs() << "\n"; 273 } 274 }); 275 } 276 277 if (op->getNumResults() == 0) { 278 rewriter.eraseOp(op); 279 return tilingResult; 280 } 281 282 // 5. If the original operations has results, modify the loop nest to yield 283 // the replacement values. 284 SmallVector<Value> replacements; 285 if (tilingResult.loops.empty()) { 286 // 5a. If there were no loops, the tiled implementation results are the 287 // replacements. 288 rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); 289 return tilingResult; 290 } 291 292 // 5b. `scf.for` with tensor semantics requires the loop nest to yield the 293 // replacement values using destructive updates. Use the `TilingInterface` 294 // to get the position of the result tiles and use that to generate the 295 // destructive update pattern, i.e., 296 // 297 // ```mlir 298 // scf.for %iv0 = ... { 299 // %0 = tiled_op 300 // } 301 // ``` 302 // 303 // is transformed to 304 // 305 // ```mlir 306 // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { 307 // %0 = tiled_op 308 // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 309 // scf.yield %1 310 // } 311 // ``` 312 NewYieldValueFn yieldValueFn = 313 [&](OpBuilder &b, Location loc, 314 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 315 SmallVector<Value> yieldedValues; 316 Attribute one = b.getIndexAttr(1); 317 for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) { 318 SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes; 319 if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, 320 resultTileOffsets, 321 resultTileSizes))) { 322 op.emitOpError("unable to get position of result ") 323 << resultNum << " of the tiled implementation"; 324 return {}; 325 } 326 SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(), 327 one); 328 Value yieldedValue = b.create<tensor::InsertSliceOp>( 329 op->getLoc(), tilingResult.tiledOp->getResult(resultNum), 330 newBBArgs[resultNum], resultTileOffsets, resultTileSizes, 331 resultTileStrides); 332 yieldedValues.push_back(yieldedValue); 333 } 334 return yieldedValues; 335 }; 336 SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields( 337 rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), 338 yieldValueFn); 339 for (const auto &loop : llvm::enumerate(tilingResult.loops)) { 340 rewriter.eraseOp(loop.value()); 341 tilingResult.loops[loop.index()] = newLoops[loop.index()]; 342 } 343 rewriter.replaceOp(op, tilingResult.loops.front().getResults()); 344 return tilingResult; 345 } 346 347 //===----------------------------------------------------------------------===// 348 // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. 349 //===----------------------------------------------------------------------===// 350 351 scf::TileConsumerAndFuseProducersUsingSCFForOp:: 352 TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, 353 scf::SCFTilingOptions options, 354 PatternBenefit benefit) 355 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 356 tilingPattern(context, std::move(options)) {} 357 358 scf::TileConsumerAndFuseProducersUsingSCFForOp:: 359 TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, 360 MLIRContext *context, 361 scf::SCFTilingOptions options, 362 PatternBenefit benefit) 363 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 364 tilingPattern(context, std::move(options)) {} 365 366 /// Return the `Value` that is defined by an operation that implements 367 /// the `TilingInterface`. Looks through `iter_args` of scf.for nest 368 /// if required. 369 static Optional<OpResult> getFusableProducer(Value v) { 370 while (auto blockArg = v.dyn_cast<BlockArgument>()) { 371 auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp()); 372 if (!loopOp) 373 return llvm::None; 374 v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); 375 } 376 if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp())) 377 return llvm::None; 378 return v.cast<OpResult>(); 379 } 380 381 // Replace iter args of the outer most loop with region args of the inner most 382 // one. 383 static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, 384 PatternRewriter &rewriter) { 385 assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() && 386 "expect same number of iter args"); 387 Block *block = &(*innerFor.getRegion().begin()); 388 for (auto it : 389 llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) { 390 Value source = std::get<0>(it); 391 Value target = std::get<1>(it); 392 source.replaceUsesWithIf(target, [&](OpOperand &use) { 393 return use.getOwner()->getBlock() == block; 394 }); 395 } 396 } 397 398 FailureOr<scf::SCFTileAndFuseResult> 399 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( 400 TilingInterface op, PatternRewriter &rewriter) const { 401 // This transformation is only valid for ops that return values (i.e. not 402 // valid to use with operations that have memref operands). 403 if (!op->getNumResults()) { 404 return rewriter.notifyMatchFailure( 405 op, "invalid pattern for op with no results"); 406 } 407 408 // 1. First tile the consumer. 409 SCFTileAndFuseResult tileAndFuseResult; 410 { 411 FailureOr<SCFTilingResult> tilingResult = 412 tilingPattern.returningMatchAndRewrite(op, rewriter); 413 if (failed(tilingResult)) { 414 return failure(); 415 } 416 tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); 417 tileAndFuseResult.loops = std::move(tilingResult->loops); 418 } 419 420 // 2. Typically, the operands of the tiled operation are slices of the 421 // operands of the untiled operation. These are expressed in IR using 422 // `tensor.extract_slice` operations with source being the operands of the 423 // untiled operation. Create a worklist of these `tensor.extract_slice` 424 // operations. If the producers of the source of the `tensor.extract_slice` 425 // can be tiled such that the tiled value is generated in-place, that 426 // effectively tiles + fuses the operations. 427 auto addCandidateSlices = [](Operation *fusedOp, 428 std::deque<tensor::ExtractSliceOp> &candidates) { 429 for (Value operand : fusedOp->getOperands()) 430 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 431 candidates.push_back(sliceOp); 432 }; 433 434 std::deque<tensor::ExtractSliceOp> candidates; 435 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 436 OpBuilder::InsertionGuard g(rewriter); 437 while (!candidates.empty()) { 438 // 2a. Traverse the slices in BFS fashion. 439 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 440 candidates.pop_front(); 441 442 // 2b. Get the producer of the source (potentially walking through 443 // `iter_args` of nested `scf.for`) 444 Optional<OpResult> fusableProducer = 445 getFusableProducer(candidateSliceOp.getSource()); 446 if (!fusableProducer) 447 continue; 448 449 // 2c. Generate the tiled implementation of the producer of the source 450 rewriter.setInsertionPoint(candidateSliceOp); 451 FailureOr<Value> fusedProducerValue = 452 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 453 fusableProducer.value()); 454 if (failed(fusedProducerValue)) 455 continue; 456 rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); 457 458 // 2d. The operands of the fused producer might themselved be slices of 459 // values produced by operations that implement the `TilingInterface`. 460 // Add these operations to the worklist. 461 Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 462 tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); 463 addCandidateSlices(fusedProducer, candidates); 464 465 // 2e. If the operation being fused creates a value that is used as `outs` 466 // in the tiled operation, the result of the unfused operation will be 467 // used in the `iter_args` of the tiled loop generated. When the 468 // operation is fused, this use in `iter_args` needs to be modified to 469 // use the destination of the fused operation. For example, starting 470 // with 471 // 472 // ```mlir 473 // %0 = linalg.init_tensor ... 474 // %1 = linalg.fill ... outs(%0:...)... 475 // %2 = linalg.matmul ... outs(%1:...).... 476 // ``` 477 // 478 // First the `linalg.matmul` gets tiled 479 // 480 // ```mlir 481 // %0 = linalg.init_tensor 482 // %1 = linalg.fill 483 // %2 = scf.for .... iter_args(%arg0 = %1)... 484 // ... 485 // ... = linalg.matmul ... 486 // 487 // ``` 488 // 489 // When the `linalg.fill` gets fused, the `iter_args` needs to be 490 // modified 491 // 492 // ```mlir 493 // %0 = linalg.init_tensor 494 // %1 = scf.for ... iter_args(%arg0 = %0)... 495 // ... 496 // %2 = linalg.fill ... 497 // %3 = linalg.matmul ... outs(%2: ...)... 498 // ``` 499 TilingInterface unfusedProducerOp = 500 cast<TilingInterface>(fusableProducer->getOwner()); 501 scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); 502 SmallVector<Value> unfusedProducerOpDestValues = 503 unfusedProducerOp.getDestinationOperands(rewriter); 504 for (OpOperand &uses : unfusedProducerOp->getUses()) { 505 if (uses.getOwner() == outerMostTiledLoop.getOperation()) { 506 unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber(); 507 unsigned operandNumber = uses.getOperandNumber(); 508 outerMostTiledLoop->setOperand( 509 operandNumber, unfusedProducerOpDestValues[resultNumber]); 510 } 511 } 512 } 513 replaceIterArgs(tileAndFuseResult.loops.front(), 514 tileAndFuseResult.loops.back(), rewriter); 515 return tileAndFuseResult; 516 } 517 518 //===----------------------------------------------------------------------===// 519 // LowerToLoopsUsingSCFForOp 520 //===----------------------------------------------------------------------===// 521 522 FailureOr<SmallVector<scf::ForOp>> 523 scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite( 524 TilingInterface op, PatternRewriter &rewriter) const { 525 SmallVector<Range> domain = op.getIterationDomain(rewriter); 526 527 // TODO: Handle cases where the op has results if needed. 528 if (op->getNumResults() > 0) { 529 return rewriter.notifyMatchFailure( 530 op, "unable to lower to loops operations with return values"); 531 } 532 533 SmallVector<Value> ivs; 534 SmallVector<scf::ForOp> loops; 535 Location loc = op.getLoc(); 536 for (auto loopRange : domain) { 537 Value offsetVal = 538 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 539 Value sizeVal = 540 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 541 Value strideVal = 542 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 543 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 544 strideVal, ValueRange{}); 545 loops.push_back(loop); 546 ivs.push_back(loop.getInductionVar()); 547 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 548 } 549 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 550 return failure(); 551 } 552 rewriter.eraseOp(op); 553 return loops; 554 } 555