1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/SCF/Utils/Utils.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Interfaces/TilingInterface.h" 23 #include "llvm/Support/Debug.h" 24 25 #define DEBUG_TYPE "tile-using-interface" 26 27 using namespace mlir; 28 29 scf::SCFTilingOptions & 30 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 31 assert(!tileSizeComputationFunction && "tile sizes already set"); 32 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 33 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 34 OpBuilder::InsertionGuard guard(b); 35 b.setInsertionPointToStart( 36 &op->getParentOfType<func::FuncOp>().getBody().front()); 37 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 38 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 39 return v; 40 })); 41 }; 42 return *this; 43 } 44 45 //===----------------------------------------------------------------------===// 46 // TileUsingSCFForOp pattern implementation. 47 //===----------------------------------------------------------------------===// 48 49 /// Generate an empty loop nest that represents the tiled loop nest shell. 50 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 51 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 52 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 53 /// the 54 /// tile processed within the inner most loop. 55 static SmallVector<scf::ForOp> 56 generateTileLoopNest(OpBuilder &builder, Location loc, 57 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 58 SmallVector<OpFoldResult> &offsets, 59 SmallVector<OpFoldResult> &sizes) { 60 assert(!loopRanges.empty() && "expected at least one loop range"); 61 assert(loopRanges.size() == tileSizeVals.size() && 62 "expected as many tile sizes as loop ranges"); 63 OpBuilder::InsertionGuard guard(builder); 64 SmallVector<scf::ForOp> loops; 65 offsets.resize(loopRanges.size()); 66 sizes.resize(loopRanges.size()); 67 68 // The tile size to use (to avoid out of bounds access) is minimum of 69 // `tileSize` and `ub - iv`, where `iv` is the induction variable 70 // of the tiled loop. 71 AffineExpr s0, s1, d0; 72 bindDims(builder.getContext(), d0); 73 bindSymbols(builder.getContext(), s0, s1); 74 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 75 76 for (auto loopRange : llvm::enumerate(loopRanges)) { 77 // No loops if tile size is zero. Set offset and size to the loop 78 // offset and size. 79 if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 80 offsets[loopRange.index()] = loopRange.value().offset; 81 sizes[loopRange.index()] = loopRange.value().size; 82 continue; 83 } 84 85 auto loop = builder.create<scf::ForOp>( 86 loc, loopRange.value().offset, loopRange.value().size, 87 tileSizeVals[loopRange.index()], ValueRange{}, 88 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 89 ValueRange /*iterArgs*/) { 90 Value boundedTileSize = builder.create<AffineMinOp>( 91 bodyLoc, minMap, 92 ValueRange{iv, tileSizeVals[loopRange.index()], 93 loopRange.value().size}); 94 sizes[loopRange.index()] = boundedTileSize; 95 builder.create<scf::YieldOp>(loc); 96 }); 97 offsets[loopRange.index()] = loop.getInductionVar(); 98 loops.push_back(loop); 99 builder.setInsertionPoint(loop.getBody()->getTerminator()); 100 } 101 return loops; 102 } 103 104 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, 105 scf::SCFTilingOptions options, 106 PatternBenefit benefit) 107 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 108 options(std::move(options)) {} 109 110 scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, 111 MLIRContext *context, 112 scf::SCFTilingOptions options, 113 PatternBenefit benefit) 114 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 115 options(std::move(options)) {} 116 117 FailureOr<scf::SCFTilingResult> 118 scf::TileUsingSCFForOp::returningMatchAndRewrite( 119 TilingInterface op, PatternRewriter &rewriter) const { 120 OpBuilder::InsertionGuard guard(rewriter); 121 rewriter.setInsertionPointAfter(op); 122 123 if (!options.tileSizeComputationFunction) { 124 return rewriter.notifyMatchFailure( 125 op, "missing tile size computation function"); 126 } 127 128 // 1. Get the range of the loops that are represented by the operation. 129 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 130 size_t numLoops = iterationDomain.size(); 131 if (numLoops == 0) { 132 return rewriter.notifyMatchFailure( 133 op, "unable to tile op with no iteration domain"); 134 } 135 136 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 137 // skips tiling a particular dimension. This convention is significantly 138 // simpler to handle instead of adjusting affine maps to account for missing 139 // dimensions. 140 SmallVector<Value, 4> tileSizeVector = 141 options.tileSizeComputationFunction(rewriter, op); 142 if (tileSizeVector.size() < iterationDomain.size()) { 143 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 144 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 145 } 146 147 scf::SCFTilingResult tilingResult; 148 SmallVector<OpFoldResult> offsets, sizes; 149 { 150 // 3. Materialize an empty loop nest that iterates over the tiles. These 151 // loops for now do not return any values even if the original operation has 152 // results. 153 tilingResult.loops = generateTileLoopNest( 154 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 155 156 LLVM_DEBUG({ 157 if (!tilingResult.loops.empty()) { 158 llvm::errs() << "LoopNest shell :\n"; 159 tilingResult.loops.front().dump(); 160 llvm::errs() << "\n"; 161 } 162 }); 163 164 // 4. Generate the tiled implementation within the inner most loop. 165 if (!tilingResult.loops.empty()) 166 rewriter.setInsertionPoint( 167 tilingResult.loops.back().getBody()->getTerminator()); 168 SmallVector<Operation *> tiledImplementation = op.getTiledImplementation( 169 rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); 170 if (tiledImplementation.size() != 1) { 171 return rewriter.notifyMatchFailure( 172 op, "expected tiled implementation to return a single op"); 173 } 174 tilingResult.tiledOp = tiledImplementation[0]; 175 176 LLVM_DEBUG({ 177 if (!tilingResult.loops.empty()) { 178 llvm::errs() << "After tiled implementation :\n"; 179 tilingResult.loops.front().dump(); 180 llvm::errs() << "\n"; 181 } 182 }); 183 } 184 185 if (op->getNumResults() == 0) { 186 rewriter.eraseOp(op); 187 return tilingResult; 188 } 189 190 // 5. If the original operations has results, modify the loop nest to yield 191 // the replacement values. 192 SmallVector<Value> replacements; 193 if (tilingResult.loops.empty()) { 194 // 5a. If there were no loops, the tiled implementation results are the 195 // replacements. 196 rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); 197 return tilingResult; 198 } 199 200 // 5b. `scf.for` with tensor semantics requires the loop nest to yield the 201 // replacement values using destructive updates. Use the `TilingInterface` 202 // to get the position of the result tiles and use that to generate the 203 // destructive update pattern, i.e., 204 // 205 // ```mlir 206 // scf.for %iv0 = ... { 207 // %0 = tiled_op 208 // } 209 // ``` 210 // 211 // is transformed to 212 // 213 // ```mlir 214 // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { 215 // %0 = tiled_op 216 // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 217 // scf.yield %1 218 // } 219 // ``` 220 NewYieldValueFn yieldValueFn = 221 [&](OpBuilder &b, Location loc, 222 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 223 SmallVector<Value> yieldedValues; 224 Attribute one = b.getIndexAttr(1); 225 for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) { 226 SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes; 227 if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, 228 resultTileOffsets, 229 resultTileSizes))) { 230 op.emitOpError("unable to get position of result ") 231 << resultNum << " of the tiled implementation"; 232 return {}; 233 } 234 SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(), 235 one); 236 Value yieldedValue = b.create<tensor::InsertSliceOp>( 237 op->getLoc(), tilingResult.tiledOp->getResult(resultNum), 238 newBBArgs[resultNum], resultTileOffsets, resultTileSizes, 239 resultTileStrides); 240 yieldedValues.push_back(yieldedValue); 241 } 242 return yieldedValues; 243 }; 244 SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields( 245 rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), 246 yieldValueFn); 247 for (auto loop : llvm::enumerate(tilingResult.loops)) { 248 rewriter.eraseOp(loop.value()); 249 tilingResult.loops[loop.index()] = newLoops[loop.index()]; 250 } 251 rewriter.replaceOp(op, tilingResult.loops.front().getResults()); 252 return tilingResult; 253 } 254 255 //===----------------------------------------------------------------------===// 256 // TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. 257 //===----------------------------------------------------------------------===// 258 259 scf::TileConsumerAndFuseProducersUsingSCFForOp:: 260 TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, 261 scf::SCFTilingOptions options, 262 PatternBenefit benefit) 263 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 264 tilingPattern(context, std::move(options)) {} 265 266 scf::TileConsumerAndFuseProducersUsingSCFForOp:: 267 TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, 268 MLIRContext *context, 269 scf::SCFTilingOptions options, 270 PatternBenefit benefit) 271 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 272 tilingPattern(context, std::move(options)) {} 273 274 /// Return the `Value` that is defined by an operation that implements 275 /// the `TilingInterface`. Looks through `iter_args` of scf.for nest 276 /// if required. 277 static Optional<OpResult> getFusableProducer(Value v) { 278 while (auto blockArg = v.dyn_cast<BlockArgument>()) { 279 auto loopOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp()); 280 if (!loopOp) 281 return llvm::None; 282 v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); 283 } 284 if (!isa_and_nonnull<TilingInterface>(v.getDefiningOp())) 285 return llvm::None; 286 return v.cast<OpResult>(); 287 } 288 289 FailureOr<scf::SCFTileAndFuseResult> 290 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( 291 TilingInterface op, PatternRewriter &rewriter) const { 292 // This transformation is only valid for ops that return values (i.e. not 293 // valid to use with operations that have memref operands). 294 if (!op->getNumResults()) { 295 return rewriter.notifyMatchFailure( 296 op, "invalid pattern for op with no results"); 297 } 298 299 // 1. First tile the consumer. 300 SCFTileAndFuseResult tileAndFuseResult; 301 { 302 FailureOr<SCFTilingResult> tilingResult = 303 tilingPattern.returningMatchAndRewrite(op, rewriter); 304 if (failed(tilingResult)) { 305 return failure(); 306 } 307 tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); 308 tileAndFuseResult.loops = std::move(tilingResult->loops); 309 } 310 311 // 2. Typically, the operands of the tiled operation are slices of the 312 // operands of the untiled operation. These are expressed in IR using 313 // `tensor.extract_slice` operations with source being the operands of the 314 // untiled operation. Create a worklist of these `tensor.extract_slice` 315 // operations. If the producers of the source of the `tensor.extract_slice` 316 // can be tiled such that the tiled value is generated in-place, that 317 // effectively tiles + fuses the operations. 318 auto addCandidateSlices = [](Operation *fusedOp, 319 std::deque<tensor::ExtractSliceOp> &candidates) { 320 for (Value operand : fusedOp->getOperands()) 321 if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) 322 candidates.push_back(sliceOp); 323 }; 324 325 std::deque<tensor::ExtractSliceOp> candidates; 326 addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); 327 OpBuilder::InsertionGuard g(rewriter); 328 while (!candidates.empty()) { 329 // 2a. Traverse the slices in BFS fashion. 330 tensor::ExtractSliceOp candidateSliceOp = candidates.front(); 331 candidates.pop_front(); 332 333 // 2b. Get the producer of the source (potentially walking through 334 // `iter_args` of nested `scf.for`) 335 Optional<OpResult> fusableProducer = 336 getFusableProducer(candidateSliceOp.source()); 337 if (!fusableProducer) 338 continue; 339 340 // 2c. Generate the tiled implementation of the producer of the source 341 rewriter.setInsertionPoint(candidateSliceOp); 342 FailureOr<Value> fusedProducerValue = 343 tensor::replaceExtractSliceWithTiledProducer( 344 rewriter, candidateSliceOp, fusableProducer.getValue()); 345 if (failed(fusedProducerValue)) 346 continue; 347 rewriter.replaceOp(candidateSliceOp, fusedProducerValue.getValue()); 348 349 // 2d. The operands of the fused producer might themselved be slices of 350 // values produced by operations that implement the `TilingInterface`. 351 // Add these operations to the worklist. 352 Operation *fusedProducer = fusedProducerValue->getDefiningOp(); 353 tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); 354 addCandidateSlices(fusedProducer, candidates); 355 356 // 2e. If the operation being fused creates a value that is used as `outs` 357 // in the tiled operation, the result of the unfused operation will be 358 // used in the `iter_args` of the tiled loop generated. When the 359 // operation is fused, this use in `iter_args` needs to be modified to 360 // use the destination of the fused operation. For example, starting 361 // with 362 // 363 // ```mlir 364 // %0 = linalg.init_tensor ... 365 // %1 = linalg.fill ... outs(%0:...)... 366 // %2 = linalg.matmul ... outs(%1:...).... 367 // ``` 368 // 369 // First the `linalg.matmul` gets tiled 370 // 371 // ```mlir 372 // %0 = linalg.init_tensor 373 // %1 = linalg.fill 374 // %2 = scf.for .... iter_args(%arg0 = %1)... 375 // ... 376 // ... = linalg.matmul ... 377 // 378 // ``` 379 // 380 // When the `linalg.fill` gets fused, the `iter_args` needs to be 381 // modified 382 // 383 // ```mlir 384 // %0 = linalg.init_tensor 385 // %1 = scf.for ... iter_args(%arg0 = %0)... 386 // ... 387 // %2 = linalg.fill ... 388 // %3 = linalg.matmul ... outs(%2: ...)... 389 // ``` 390 TilingInterface unfusedProducerOp = 391 cast<TilingInterface>(fusableProducer->getOwner()); 392 scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); 393 SmallVector<Value> unfusedProducerOpDestValues = 394 unfusedProducerOp.getDestinationOperands(rewriter); 395 for (OpOperand &uses : unfusedProducerOp->getUses()) { 396 if (uses.getOwner() == outerMostTiledLoop.getOperation()) { 397 unsigned resultNumber = uses.get().cast<OpResult>().getResultNumber(); 398 unsigned operandNumber = uses.getOperandNumber(); 399 outerMostTiledLoop->setOperand( 400 operandNumber, unfusedProducerOpDestValues[resultNumber]); 401 } 402 } 403 } 404 return tileAndFuseResult; 405 } 406