1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Arithmetic/Utils/Utils.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/SCF/Utils/Utils.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/IR/Matchers.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Interfaces/TilingInterface.h" 24 #include "llvm/Support/Debug.h" 25 26 #define DEBUG_TYPE "tile-using-interface" 27 28 using namespace mlir; 29 30 scf::SCFTilingOptions & 31 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 32 assert(!tileSizeComputationFunction && "tile sizes already set"); 33 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 34 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 35 OpBuilder::InsertionGuard guard(b); 36 b.setInsertionPointToStart( 37 &op->getParentOfType<func::FuncOp>().getBody().front()); 38 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 39 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 40 return v; 41 })); 42 }; 43 return *this; 44 } 45 46 /// Helper method to adjust the interchange vector to match the iteration 47 /// domain. 48 static SmallVector<unsigned> 49 fillInterchangeVector(ArrayRef<unsigned> interchangeVector, 50 size_t iterationDomainSize) { 51 SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector); 52 if (filledVector.size() < iterationDomainSize) { 53 auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize); 54 filledVector.append(range.begin(), range.end()); 55 } 56 if (filledVector.size() > iterationDomainSize) 57 filledVector.resize(iterationDomainSize); 58 return filledVector; 59 } 60 61 /// Helper method to apply permutation to a vector 62 template <typename T> 63 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector, 64 ArrayRef<unsigned> interchange) { 65 assert(interchange.size() == vector.size()); 66 return llvm::to_vector( 67 llvm::map_range(interchange, [&](unsigned val) { return vector[val]; })); 68 } 69 /// Helper method to apply to invert a permutation. 70 static SmallVector<unsigned> 71 invertPermutationVector(ArrayRef<unsigned> interchange) { 72 SmallVector<unsigned> inversion(interchange.size()); 73 for (auto pos : llvm::enumerate(interchange)) { 74 inversion[pos.value()] = pos.index(); 75 } 76 return inversion; 77 } 78 /// Method to check if an interchange vector is a permutation. 79 static bool isPermutation(ArrayRef<unsigned> interchange) { 80 llvm::SmallDenseSet<unsigned, 4> seenVals; 81 for (auto val : interchange) { 82 if (seenVals.count(val)) 83 return false; 84 seenVals.insert(val); 85 } 86 return seenVals.size() == interchange.size(); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // TileUsingSCFForOp pattern implementation. 91 //===----------------------------------------------------------------------===// 92 93 /// Generate an empty loop nest that represents the tiled loop nest shell. 94 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 95 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 96 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 97 /// the 98 /// tile processed within the inner most loop. 99 static SmallVector<scf::ForOp> 100 generateTileLoopNest(OpBuilder &builder, Location loc, 101 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 102 SmallVector<OpFoldResult> &offsets, 103 SmallVector<OpFoldResult> &sizes) { 104 assert(!loopRanges.empty() && "expected at least one loop range"); 105 assert(loopRanges.size() == tileSizeVals.size() && 106 "expected as many tile sizes as loop ranges"); 107 OpBuilder::InsertionGuard guard(builder); 108 SmallVector<scf::ForOp> loops; 109 offsets.resize(loopRanges.size()); 110 sizes.resize(loopRanges.size()); 111 112 // The tile size to use (to avoid out of bounds access) is minimum of 113 // `tileSize` and `ub - iv`, where `iv` is the induction variable 114 // of the tiled loop. 115 AffineExpr s0, s1, d0; 116 bindDims(builder.getContext(), d0); 117 bindSymbols(builder.getContext(), s0, s1); 118 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 119 120 for (auto loopRange : llvm::enumerate(loopRanges)) { 121 Value offset = 122 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); 123 Value size = 124 getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); 125 // No loops if tile size is zero. Set offset and size to the loop 126 // offset and size. 127 if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 128 offsets[loopRange.index()] = offset; 129 sizes[loopRange.index()] = size; 130 continue; 131 } 132 133 auto loop = builder.create<scf::ForOp>( 134 loc, offset, size, tileSizeVals[loopRange.index()], ValueRange{}, 135 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 136 ValueRange /*iterArgs*/) { 137 Value boundedTileSize = builder.create<AffineMinOp>( 138 bodyLoc, minMap, 139 ValueRange{iv, tileSizeVals[loopRange.index()], size}); 140 sizes[loopRange.index()] = boundedTileSize; 141 builder.create<scf::YieldOp>(loc); 142 }); 143 offsets[loopRange.index()] = loop.getInductionVar(); 144 loops.push_back(loop); 145 builder.setInsertionPoint(loop.getBody()->getTerminator()); 146 } 147 return loops; 148 } 149 150 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, 151 scf::SCFTilingOptions options, 152 PatternBenefit benefit) 153 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 154 options(std::move(options)) {} 155 156 scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, 157 MLIRContext *context, 158 scf::SCFTilingOptions options, 159 PatternBenefit benefit) 160 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 161 options(std::move(options)) {} 162 163 FailureOr<scf::SCFTilingResult> 164 scf::TileUsingSCFForOp::returningMatchAndRewrite( 165 TilingInterface op, PatternRewriter &rewriter) const { 166 OpBuilder::InsertionGuard guard(rewriter); 167 rewriter.setInsertionPointAfter(op); 168 169 if (!options.tileSizeComputationFunction) { 170 return rewriter.notifyMatchFailure( 171 op, "missing tile size computation function"); 172 } 173 174 // 1. Get the range of the loops that are represented by the operation. 175 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 176 size_t numLoops = iterationDomain.size(); 177 if (numLoops == 0) { 178 return rewriter.notifyMatchFailure( 179 op, "unable to tile op with no iteration domain"); 180 } 181 182 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 183 // skips tiling a particular dimension. This convention is significantly 184 // simpler to handle instead of adjusting affine maps to account for missing 185 // dimensions. 186 SmallVector<Value> tileSizeVector = 187 options.tileSizeComputationFunction(rewriter, op); 188 if (tileSizeVector.size() < iterationDomain.size()) { 189 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 190 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 191 } 192 193 scf::SCFTilingResult tilingResult; 194 SmallVector<OpFoldResult> offsets, sizes; 195 { 196 // If there is an interchange specified, permute the iteration domain and 197 // the tile sizes. 198 SmallVector<unsigned> interchangeVector; 199 if (!options.interchangeVector.empty()) { 200 interchangeVector = fillInterchangeVector(options.interchangeVector, 201 iterationDomain.size()); 202 } 203 if (!interchangeVector.empty()) { 204 if (!isPermutation(interchangeVector)) { 205 return rewriter.notifyMatchFailure( 206 op, "invalid intechange vector, not a permutation of the entire " 207 "iteration space"); 208 } 209 210 iterationDomain = 211 applyPermutationToVector(iterationDomain, interchangeVector); 212 tileSizeVector = 213 applyPermutationToVector(tileSizeVector, interchangeVector); 214 } 215 216 // 3. Materialize an empty loop nest that iterates over the tiles. These 217 // loops for now do not return any values even if the original operation has 218 // results. 219 tilingResult.loops = generateTileLoopNest( 220 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 221 222 if (!interchangeVector.empty()) { 223 auto inversePermutation = invertPermutationVector(interchangeVector); 224 offsets = applyPermutationToVector(offsets, inversePermutation); 225 sizes = applyPermutationToVector(sizes, inversePermutation); 226 } 227 228 LLVM_DEBUG({ 229 if (!tilingResult.loops.empty()) { 230 llvm::errs() << "LoopNest shell :\n"; 231 tilingResult.loops.front().dump(); 232 llvm::errs() << "\n"; 233 } 234 }); 235 236 // 4. Generate the tiled implementation within the inner most loop. 237 if (!tilingResult.loops.empty()) 238 rewriter.setInsertionPoint( 239 tilingResult.loops.back().getBody()->getTerminator()); 240 SmallVector<Operation *> tiledImplementation = op.getTiledImplementation( 241 rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); 242 if (tiledImplementation.size() != 1) { 243 return rewriter.notifyMatchFailure( 244 op, "expected tiled implementation to return a single op"); 245 } 246 tilingResult.tiledOp = tiledImplementation[0]; 247 248 LLVM_DEBUG({ 249 if (!tilingResult.loops.empty()) { 250 llvm::errs() << "After tiled implementation :\n"; 251 tilingResult.loops.front().dump(); 252 llvm::errs() << "\n"; 253 } 254 }); 255 } 256 257 if (op->getNumResults() == 0) { 258 rewriter.eraseOp(op); 259 return tilingResult; 260 } 261 262 // 5. If the original operations has results, modify the loop nest to yield 263 // the replacement values. 264 SmallVector<Value> replacements; 265 if (tilingResult.loops.empty()) { 266 // 5a. If there were no loops, the tiled implementation results are the 267 // replacements. 268 rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); 269 return tilingResult; 270 } 271 272 // 5b. `scf.for` with tensor semantics requires the loop nest to yield the 273 // replacement values using destructive updates. Use the `TilingInterface` 274 // to get the position of the result tiles and use that to generate the 275 // destructive update pattern, i.e., 276 // 277 // ```mlir 278 // scf.for %iv0 = ... { 279 // %0 = tiled_op 280 // } 281 // ``` 282 // 283 // is transformed to 284 // 285 // ```mlir 286 // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { 287 // %0 = tiled_op 288 // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 289 // scf.yield %1 290 // } 291 // ``` 292 NewYieldValueFn yieldValueFn = 293 [&](OpBuilder &b, Location loc, 294 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 295 SmallVector<Value> yieldedValues; 296 Attribute one = b.getIndexAttr(1); 297 for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) { 298 SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes; 299 if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, 300 resultTileOffsets, 301 resultTileSizes))) { 302 op.emitOpError("unable to get position of result ") 303 << resultNum << " of the tiled implementation"; 304 return {}; 305 } 306 SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(), 307 one); 308 Value yieldedValue = b.create<tensor::InsertSliceOp>( 309 op->getLoc(), tilingResult.tiledOp->getResult(resultNum), 310 newBBArgs[resultNum], resultTileOffsets, resultTileSizes, 311 resultTileStrides); 312 yieldedValues.push_back(yieldedValue); 313 } 314 return yieldedValues; 315 }; 316 SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields( 317 rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), 318 yieldValueFn); 319 for (const auto &loop : llvm::enumerate(tilingResult.loops)) { 320 rewriter.eraseOp(loop.value()); 321 tilingResult.loops[loop.index()] = newLoops[loop.index()]; 322 } 323 rewriter.replaceOp(op, tilingResult.loops.front().getResults()); 324 return tilingResult; 325 } 326 327 //===----------------------------------------------------------------------===// 328 // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. 329 //===----------------------------------------------------------------------===// 330 331 scf::TileConsumerAndFuseProducersUsingSCFForOp:: 332 TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, 333 scf::SCFTilingOptions options, 334 PatternBenefit benefit) 335 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 336 tilingPattern(context, std::move(options)) {} 337 338 scf::TileConsumerAndFuseProducersUsingSCFForOp:: 339 TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, 340 MLIRContext *context, 341 scf::SCFTilingOptions options, 342 PatternBenefit benefit) 343 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 344 tilingPattern(context, std::move(options)) {} 345 346 /// Return the `Value` that is defined by an operation that implements 347 /// the `TilingInterface`. Looks through `iter_args` of scf.for nest 348 /// if required. 349 static Optional<OpResult> getFusableProducer(Value v) { 350 while (auto blockArg = v.dyn_cast<BlockArgument>()) { 351 auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp()); 352 if (!loopOp) 353 return llvm::None; 354 v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); 355 } 356 if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp())) 357 return llvm::None; 358 return v.cast<OpResult>(); 359 } 360 361 // Replace iter args of the outer most loop with region args of the inner most 362 // one. 363 static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, 364 PatternRewriter &rewriter) { 365 assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() && 366 "expect same number of iter args"); 367 Block *block = &(*innerFor.getRegion().begin()); 368 for (auto it : 369 llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) { 370 Value source = std::get<0>(it); 371 Value target = std::get<1>(it); 372 source.replaceUsesWithIf(target, [&](OpOperand &use) { 373 return use.getOwner()->getBlock() == block; 374 }); 375 } 376 } 377 378 FailureOr<scf::SCFTileAndFuseResult> 379 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( 380 TilingInterface op, PatternRewriter &rewriter) const { 381 // This transformation is only valid for ops that return values (i.e. not 382 // valid to use with operations that have memref operands). 383 if (!op->getNumResults()) { 384 return rewriter.notifyMatchFailure( 385 op, "invalid pattern for op with no results"); 386 } 387 388 // 1. First tile the consumer. 389 SCFTileAndFuseResult tileAndFuseResult; 390 { 391 FailureOr<SCFTilingResult> tilingResult = 392 tilingPattern.returningMatchAndRewrite(op, rewriter); 393 if (failed(tilingResult)) { 394 return failure(); 395 } 396 tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); 397 tileAndFuseResult.loops = std::move(tilingResult->loops); 398 } 399 400 // 2. Typically, the operands of the tiled operation are slices of the 401 // operands of the untiled operation. These are expressed in IR using 402 // `tensor.extract_slice` operations with source being the operands of the 403 // untiled operation. Create a worklist of these `tensor.extract_slice` 404 // operations. If the producers of the source of the `tensor.extract_slice` 405 // can be tiled such that the tiled value is generated in-place, that 406 // effectively tiles + fuses the operations. 407 auto addCandidateSlices = [](Operation *fusedOp, 408 std::deque<tensor::ExtractSliceOp> &candidates) { 409 for (Value operand : fusedOp->getOperands()) 410 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 411 candidates.push_back(sliceOp); 412 }; 413 414 std::deque<tensor::ExtractSliceOp> candidates; 415 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 416 OpBuilder::InsertionGuard g(rewriter); 417 while (!candidates.empty()) { 418 // 2a. Traverse the slices in BFS fashion. 419 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 420 candidates.pop_front(); 421 422 // 2b. Get the producer of the source (potentially walking through 423 // `iter_args` of nested `scf.for`) 424 Optional<OpResult> fusableProducer = 425 getFusableProducer(candidateSliceOp.getSource()); 426 if (!fusableProducer) 427 continue; 428 429 // 2c. Generate the tiled implementation of the producer of the source 430 rewriter.setInsertionPoint(candidateSliceOp); 431 FailureOr<Value> fusedProducerValue = 432 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 433 fusableProducer.value()); 434 if (failed(fusedProducerValue)) 435 continue; 436 rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); 437 438 // 2d. The operands of the fused producer might themselved be slices of 439 // values produced by operations that implement the `TilingInterface`. 440 // Add these operations to the worklist. 441 Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 442 tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); 443 addCandidateSlices(fusedProducer, candidates); 444 445 // 2e. If the operation being fused creates a value that is used as `outs` 446 // in the tiled operation, the result of the unfused operation will be 447 // used in the `iter_args` of the tiled loop generated. When the 448 // operation is fused, this use in `iter_args` needs to be modified to 449 // use the destination of the fused operation. For example, starting 450 // with 451 // 452 // ```mlir 453 // %0 = linalg.init_tensor ... 454 // %1 = linalg.fill ... outs(%0:...)... 455 // %2 = linalg.matmul ... outs(%1:...).... 456 // ``` 457 // 458 // First the `linalg.matmul` gets tiled 459 // 460 // ```mlir 461 // %0 = linalg.init_tensor 462 // %1 = linalg.fill 463 // %2 = scf.for .... iter_args(%arg0 = %1)... 464 // ... 465 // ... = linalg.matmul ... 466 // 467 // ``` 468 // 469 // When the `linalg.fill` gets fused, the `iter_args` needs to be 470 // modified 471 // 472 // ```mlir 473 // %0 = linalg.init_tensor 474 // %1 = scf.for ... iter_args(%arg0 = %0)... 475 // ... 476 // %2 = linalg.fill ... 477 // %3 = linalg.matmul ... outs(%2: ...)... 478 // ``` 479 TilingInterface unfusedProducerOp = 480 cast<TilingInterface>(fusableProducer->getOwner()); 481 scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); 482 SmallVector<Value> unfusedProducerOpDestValues = 483 unfusedProducerOp.getDestinationOperands(rewriter); 484 for (OpOperand &uses : unfusedProducerOp->getUses()) { 485 if (uses.getOwner() == outerMostTiledLoop.getOperation()) { 486 unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber(); 487 unsigned operandNumber = uses.getOperandNumber(); 488 outerMostTiledLoop->setOperand( 489 operandNumber, unfusedProducerOpDestValues[resultNumber]); 490 } 491 } 492 } 493 replaceIterArgs(tileAndFuseResult.loops.front(), 494 tileAndFuseResult.loops.back(), rewriter); 495 return tileAndFuseResult; 496 } 497 498 //===----------------------------------------------------------------------===// 499 // LowerToLoopsUsingSCFForOp 500 //===----------------------------------------------------------------------===// 501 502 FailureOr<SmallVector<scf::ForOp>> 503 scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite( 504 TilingInterface op, PatternRewriter &rewriter) const { 505 SmallVector<Range> domain = op.getIterationDomain(rewriter); 506 507 // TODO: Handle cases where the op has results if needed. 508 if (op->getNumResults() > 0) { 509 return rewriter.notifyMatchFailure( 510 op, "unable to lower to loops operations with return values"); 511 } 512 513 SmallVector<Value> ivs; 514 SmallVector<scf::ForOp> loops; 515 Location loc = op.getLoc(); 516 for (auto loopRange : domain) { 517 Value offsetVal = 518 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); 519 Value sizeVal = 520 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); 521 Value strideVal = 522 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); 523 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, 524 strideVal, ValueRange{}); 525 loops.push_back(loop); 526 ivs.push_back(loop.getInductionVar()); 527 rewriter.setInsertionPoint(loop.getBody()->getTerminator()); 528 } 529 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { 530 return failure(); 531 } 532 rewriter.eraseOp(op); 533 return loops; 534 } 535