1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/SCF/Utils/Utils.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Interfaces/TilingInterface.h" 23 #include "llvm/Support/Debug.h" 24 25 #define DEBUG_TYPE "tile-using-interface" 26 27 using namespace mlir; 28 29 scf::SCFTilingOptions & 30 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 31 assert(!tileSizeComputationFunction && "tile sizes already set"); 32 SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); 33 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 34 OpBuilder::InsertionGuard guard(b); 35 b.setInsertionPointToStart( 36 &op->getParentOfType<func::FuncOp>().getBody().front()); 37 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 38 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 39 return v; 40 })); 41 }; 42 return *this; 43 } 44 45 /// Helper method to adjust the interchange vector to match the iteration 46 /// domain. 47 static SmallVector<unsigned> 48 fillInterchangeVector(ArrayRef<unsigned> interchangeVector, 49 size_t iterationDomainSize) { 50 SmallVector<unsigned> filledVector = llvm::to_vector(interchangeVector); 51 if (filledVector.size() < iterationDomainSize) { 52 auto range = llvm::seq<unsigned>(filledVector.size(), iterationDomainSize); 53 filledVector.append(range.begin(), range.end()); 54 } 55 if (filledVector.size() > iterationDomainSize) 56 filledVector.resize(iterationDomainSize); 57 return filledVector; 58 } 59 60 /// Helper method to apply permutation to a vector 61 template <typename T> 62 static SmallVector<T> applyPermutationToVector(const SmallVector<T> &vector, 63 ArrayRef<unsigned> interchange) { 64 assert(interchange.size() == vector.size()); 65 return llvm::to_vector( 66 llvm::map_range(interchange, [&](unsigned val) { return vector[val]; })); 67 } 68 /// Helper method to apply to invert a permutation. 69 static SmallVector<unsigned> 70 invertPermutationVector(ArrayRef<unsigned> interchange) { 71 SmallVector<unsigned> inversion(interchange.size()); 72 for (auto pos : llvm::enumerate(interchange)) { 73 inversion[pos.value()] = pos.index(); 74 } 75 return inversion; 76 } 77 /// Method to check if an interchange vector is a permutation. 78 static bool isPermutation(ArrayRef<unsigned> interchange) { 79 llvm::SmallDenseSet<unsigned, 4> seenVals; 80 for (auto val : interchange) { 81 if (seenVals.count(val)) 82 return false; 83 seenVals.insert(val); 84 } 85 return seenVals.size() == interchange.size(); 86 } 87 88 //===----------------------------------------------------------------------===// 89 // TileUsingSCFForOp pattern implementation. 90 //===----------------------------------------------------------------------===// 91 92 /// Generate an empty loop nest that represents the tiled loop nest shell. 93 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 94 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 95 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 96 /// the 97 /// tile processed within the inner most loop. 98 static SmallVector<scf::ForOp> 99 generateTileLoopNest(OpBuilder &builder, Location loc, 100 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 101 SmallVector<OpFoldResult> &offsets, 102 SmallVector<OpFoldResult> &sizes) { 103 assert(!loopRanges.empty() && "expected at least one loop range"); 104 assert(loopRanges.size() == tileSizeVals.size() && 105 "expected as many tile sizes as loop ranges"); 106 OpBuilder::InsertionGuard guard(builder); 107 SmallVector<scf::ForOp> loops; 108 offsets.resize(loopRanges.size()); 109 sizes.resize(loopRanges.size()); 110 111 // The tile size to use (to avoid out of bounds access) is minimum of 112 // `tileSize` and `ub - iv`, where `iv` is the induction variable 113 // of the tiled loop. 114 AffineExpr s0, s1, d0; 115 bindDims(builder.getContext(), d0); 116 bindSymbols(builder.getContext(), s0, s1); 117 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 118 119 for (auto loopRange : llvm::enumerate(loopRanges)) { 120 // No loops if tile size is zero. Set offset and size to the loop 121 // offset and size. 122 if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 123 offsets[loopRange.index()] = loopRange.value().offset; 124 sizes[loopRange.index()] = loopRange.value().size; 125 continue; 126 } 127 128 auto loop = builder.create<scf::ForOp>( 129 loc, loopRange.value().offset, loopRange.value().size, 130 tileSizeVals[loopRange.index()], ValueRange{}, 131 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 132 ValueRange /*iterArgs*/) { 133 Value boundedTileSize = builder.create<AffineMinOp>( 134 bodyLoc, minMap, 135 ValueRange{iv, tileSizeVals[loopRange.index()], 136 loopRange.value().size}); 137 sizes[loopRange.index()] = boundedTileSize; 138 builder.create<scf::YieldOp>(loc); 139 }); 140 offsets[loopRange.index()] = loop.getInductionVar(); 141 loops.push_back(loop); 142 builder.setInsertionPoint(loop.getBody()->getTerminator()); 143 } 144 return loops; 145 } 146 147 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, 148 scf::SCFTilingOptions options, 149 PatternBenefit benefit) 150 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 151 options(std::move(options)) {} 152 153 scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, 154 MLIRContext *context, 155 scf::SCFTilingOptions options, 156 PatternBenefit benefit) 157 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 158 options(std::move(options)) {} 159 160 FailureOr<scf::SCFTilingResult> 161 scf::TileUsingSCFForOp::returningMatchAndRewrite( 162 TilingInterface op, PatternRewriter &rewriter) const { 163 OpBuilder::InsertionGuard guard(rewriter); 164 rewriter.setInsertionPointAfter(op); 165 166 if (!options.tileSizeComputationFunction) { 167 return rewriter.notifyMatchFailure( 168 op, "missing tile size computation function"); 169 } 170 171 // 1. Get the range of the loops that are represented by the operation. 172 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 173 size_t numLoops = iterationDomain.size(); 174 if (numLoops == 0) { 175 return rewriter.notifyMatchFailure( 176 op, "unable to tile op with no iteration domain"); 177 } 178 179 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 180 // skips tiling a particular dimension. This convention is significantly 181 // simpler to handle instead of adjusting affine maps to account for missing 182 // dimensions. 183 SmallVector<Value> tileSizeVector = 184 options.tileSizeComputationFunction(rewriter, op); 185 if (tileSizeVector.size() < iterationDomain.size()) { 186 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 187 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 188 } 189 190 scf::SCFTilingResult tilingResult; 191 SmallVector<OpFoldResult> offsets, sizes; 192 { 193 // If there is an interchange specified, permute the iteration domain and 194 // the tile sizes. 195 SmallVector<unsigned> interchangeVector; 196 if (!options.interchangeVector.empty()) { 197 interchangeVector = fillInterchangeVector(options.interchangeVector, 198 iterationDomain.size()); 199 } 200 if (!interchangeVector.empty()) { 201 if (!isPermutation(interchangeVector)) { 202 return rewriter.notifyMatchFailure( 203 op, "invalid intechange vector, not a permutation of the entire " 204 "iteration space"); 205 } 206 207 iterationDomain = 208 applyPermutationToVector(iterationDomain, interchangeVector); 209 tileSizeVector = 210 applyPermutationToVector(tileSizeVector, interchangeVector); 211 } 212 213 // 3. Materialize an empty loop nest that iterates over the tiles. These 214 // loops for now do not return any values even if the original operation has 215 // results. 216 tilingResult.loops = generateTileLoopNest( 217 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 218 219 if (!interchangeVector.empty()) { 220 auto inversePermutation = invertPermutationVector(interchangeVector); 221 offsets = applyPermutationToVector(offsets, inversePermutation); 222 sizes = applyPermutationToVector(sizes, inversePermutation); 223 } 224 225 LLVM_DEBUG({ 226 if (!tilingResult.loops.empty()) { 227 llvm::errs() << "LoopNest shell :\n"; 228 tilingResult.loops.front().dump(); 229 llvm::errs() << "\n"; 230 } 231 }); 232 233 // 4. Generate the tiled implementation within the inner most loop. 234 if (!tilingResult.loops.empty()) 235 rewriter.setInsertionPoint( 236 tilingResult.loops.back().getBody()->getTerminator()); 237 SmallVector<Operation *> tiledImplementation = op.getTiledImplementation( 238 rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); 239 if (tiledImplementation.size() != 1) { 240 return rewriter.notifyMatchFailure( 241 op, "expected tiled implementation to return a single op"); 242 } 243 tilingResult.tiledOp = tiledImplementation[0]; 244 245 LLVM_DEBUG({ 246 if (!tilingResult.loops.empty()) { 247 llvm::errs() << "After tiled implementation :\n"; 248 tilingResult.loops.front().dump(); 249 llvm::errs() << "\n"; 250 } 251 }); 252 } 253 254 if (op->getNumResults() == 0) { 255 rewriter.eraseOp(op); 256 return tilingResult; 257 } 258 259 // 5. If the original operations has results, modify the loop nest to yield 260 // the replacement values. 261 SmallVector<Value> replacements; 262 if (tilingResult.loops.empty()) { 263 // 5a. If there were no loops, the tiled implementation results are the 264 // replacements. 265 rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); 266 return tilingResult; 267 } 268 269 // 5b. `scf.for` with tensor semantics requires the loop nest to yield the 270 // replacement values using destructive updates. Use the `TilingInterface` 271 // to get the position of the result tiles and use that to generate the 272 // destructive update pattern, i.e., 273 // 274 // ```mlir 275 // scf.for %iv0 = ... { 276 // %0 = tiled_op 277 // } 278 // ``` 279 // 280 // is transformed to 281 // 282 // ```mlir 283 // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { 284 // %0 = tiled_op 285 // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 286 // scf.yield %1 287 // } 288 // ``` 289 NewYieldValueFn yieldValueFn = 290 [&](OpBuilder &b, Location loc, 291 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 292 SmallVector<Value> yieldedValues; 293 Attribute one = b.getIndexAttr(1); 294 for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) { 295 SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes; 296 if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, 297 resultTileOffsets, 298 resultTileSizes))) { 299 op.emitOpError("unable to get position of result ") 300 << resultNum << " of the tiled implementation"; 301 return {}; 302 } 303 SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(), 304 one); 305 Value yieldedValue = b.create<tensor::InsertSliceOp>( 306 op->getLoc(), tilingResult.tiledOp->getResult(resultNum), 307 newBBArgs[resultNum], resultTileOffsets, resultTileSizes, 308 resultTileStrides); 309 yieldedValues.push_back(yieldedValue); 310 } 311 return yieldedValues; 312 }; 313 SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields( 314 rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), 315 yieldValueFn); 316 for (const auto &loop : llvm::enumerate(tilingResult.loops)) { 317 rewriter.eraseOp(loop.value()); 318 tilingResult.loops[loop.index()] = newLoops[loop.index()]; 319 } 320 rewriter.replaceOp(op, tilingResult.loops.front().getResults()); 321 return tilingResult; 322 } 323 324 //===----------------------------------------------------------------------===// 325 // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. 326 //===----------------------------------------------------------------------===// 327 328 scf::TileConsumerAndFuseProducersUsingSCFForOp:: 329 TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, 330 scf::SCFTilingOptions options, 331 PatternBenefit benefit) 332 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 333 tilingPattern(context, std::move(options)) {} 334 335 scf::TileConsumerAndFuseProducersUsingSCFForOp:: 336 TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, 337 MLIRContext *context, 338 scf::SCFTilingOptions options, 339 PatternBenefit benefit) 340 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 341 tilingPattern(context, std::move(options)) {} 342 343 /// Return the `Value` that is defined by an operation that implements 344 /// the `TilingInterface`. Looks through `iter_args` of scf.for nest 345 /// if required. 346 static Optional<OpResult> getFusableProducer(Value v) { 347 while (auto blockArg = v.dyn_cast<BlockArgument>()) { 348 auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp()); 349 if (!loopOp) 350 return llvm::None; 351 v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); 352 } 353 if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp())) 354 return llvm::None; 355 return v.cast<OpResult>(); 356 } 357 358 // Replace iter args of the outer most loop with region args of the inner most 359 // one. 360 static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, 361 PatternRewriter &rewriter) { 362 assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() && 363 "expect same number of iter args"); 364 Block *block = &(*innerFor.getRegion().begin()); 365 for (auto it : 366 llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) { 367 Value source = std::get<0>(it); 368 Value target = std::get<1>(it); 369 source.replaceUsesWithIf(target, [&](OpOperand &use) { 370 return use.getOwner()->getBlock() == block; 371 }); 372 } 373 } 374 375 FailureOr<scf::SCFTileAndFuseResult> 376 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( 377 TilingInterface op, PatternRewriter &rewriter) const { 378 // This transformation is only valid for ops that return values (i.e. not 379 // valid to use with operations that have memref operands). 380 if (!op->getNumResults()) { 381 return rewriter.notifyMatchFailure( 382 op, "invalid pattern for op with no results"); 383 } 384 385 // 1. First tile the consumer. 386 SCFTileAndFuseResult tileAndFuseResult; 387 { 388 FailureOr<SCFTilingResult> tilingResult = 389 tilingPattern.returningMatchAndRewrite(op, rewriter); 390 if (failed(tilingResult)) { 391 return failure(); 392 } 393 tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); 394 tileAndFuseResult.loops = std::move(tilingResult->loops); 395 } 396 397 // 2. Typically, the operands of the tiled operation are slices of the 398 // operands of the untiled operation. These are expressed in IR using 399 // `tensor.extract_slice` operations with source being the operands of the 400 // untiled operation. Create a worklist of these `tensor.extract_slice` 401 // operations. If the producers of the source of the `tensor.extract_slice` 402 // can be tiled such that the tiled value is generated in-place, that 403 // effectively tiles + fuses the operations. 404 auto addCandidateSlices = [](Operation *fusedOp, 405 std::deque<tensor::ExtractSliceOp> &candidates) { 406 for (Value operand : fusedOp->getOperands()) 407 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 408 candidates.push_back(sliceOp); 409 }; 410 411 std::deque<tensor::ExtractSliceOp> candidates; 412 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 413 OpBuilder::InsertionGuard g(rewriter); 414 while (!candidates.empty()) { 415 // 2a. Traverse the slices in BFS fashion. 416 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 417 candidates.pop_front(); 418 419 // 2b. Get the producer of the source (potentially walking through 420 // `iter_args` of nested `scf.for`) 421 Optional<OpResult> fusableProducer = 422 getFusableProducer(candidateSliceOp.getSource()); 423 if (!fusableProducer) 424 continue; 425 426 // 2c. Generate the tiled implementation of the producer of the source 427 rewriter.setInsertionPoint(candidateSliceOp); 428 FailureOr<Value> fusedProducerValue = 429 tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, 430 fusableProducer.value()); 431 if (failed(fusedProducerValue)) 432 continue; 433 rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); 434 435 // 2d. The operands of the fused producer might themselved be slices of 436 // values produced by operations that implement the `TilingInterface`. 437 // Add these operations to the worklist. 438 Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 439 tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); 440 addCandidateSlices(fusedProducer, candidates); 441 442 // 2e. If the operation being fused creates a value that is used as `outs` 443 // in the tiled operation, the result of the unfused operation will be 444 // used in the `iter_args` of the tiled loop generated. When the 445 // operation is fused, this use in `iter_args` needs to be modified to 446 // use the destination of the fused operation. For example, starting 447 // with 448 // 449 // ```mlir 450 // %0 = linalg.init_tensor ... 451 // %1 = linalg.fill ... outs(%0:...)... 452 // %2 = linalg.matmul ... outs(%1:...).... 453 // ``` 454 // 455 // First the `linalg.matmul` gets tiled 456 // 457 // ```mlir 458 // %0 = linalg.init_tensor 459 // %1 = linalg.fill 460 // %2 = scf.for .... iter_args(%arg0 = %1)... 461 // ... 462 // ... = linalg.matmul ... 463 // 464 // ``` 465 // 466 // When the `linalg.fill` gets fused, the `iter_args` needs to be 467 // modified 468 // 469 // ```mlir 470 // %0 = linalg.init_tensor 471 // %1 = scf.for ... iter_args(%arg0 = %0)... 472 // ... 473 // %2 = linalg.fill ... 474 // %3 = linalg.matmul ... outs(%2: ...)... 475 // ``` 476 TilingInterface unfusedProducerOp = 477 cast<TilingInterface>(fusableProducer->getOwner()); 478 scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); 479 SmallVector<Value> unfusedProducerOpDestValues = 480 unfusedProducerOp.getDestinationOperands(rewriter); 481 for (OpOperand &uses : unfusedProducerOp->getUses()) { 482 if (uses.getOwner() == outerMostTiledLoop.getOperation()) { 483 unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber(); 484 unsigned operandNumber = uses.getOperandNumber(); 485 outerMostTiledLoop->setOperand( 486 operandNumber, unfusedProducerOpDestValues[resultNumber]); 487 } 488 } 489 } 490 replaceIterArgs(tileAndFuseResult.loops.front(), 491 tileAndFuseResult.loops.back(), rewriter); 492 return tileAndFuseResult; 493 } 494