1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the tiling using TilingInterface. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" 14 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/SCF/Utils/Utils.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/IR/Matchers.h" 21 #include "mlir/IR/PatternMatch.h" 22 #include "mlir/Interfaces/TilingInterface.h" 23 #include "llvm/Support/Debug.h" 24 25 #define DEBUG_TYPE "tile-using-interface" 26 27 using namespace mlir; 28 29 scf::SCFTilingOptions & 30 scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { 31 assert(!tileSizeComputationFunction && "tile sizes already set"); 32 SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end()); 33 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { 34 OpBuilder::InsertionGuard guard(b); 35 b.setInsertionPointToStart( 36 &op->getParentOfType<func::FuncOp>().getBody().front()); 37 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { 38 Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); 39 return v; 40 })); 41 }; 42 return *this; 43 } 44 45 /// Generate an empty loop nest that represents the tiled loop nest shell. 46 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. 47 /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. 48 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of 49 /// the 50 /// tile processed within the inner most loop. 51 static SmallVector<scf::ForOp> 52 generateTileLoopNest(OpBuilder &builder, Location loc, 53 ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, 54 SmallVector<OpFoldResult> &offsets, 55 SmallVector<OpFoldResult> &sizes) { 56 assert(!loopRanges.empty() && "expected at least one loop range"); 57 assert(loopRanges.size() == tileSizeVals.size() && 58 "expected as many tile sizes as loop ranges"); 59 OpBuilder::InsertionGuard guard(builder); 60 SmallVector<scf::ForOp> loops; 61 offsets.resize(loopRanges.size()); 62 sizes.resize(loopRanges.size()); 63 64 // The tile size to use (to avoid out of bounds access) is minimum of 65 // `tileSize` and `ub - iv`, where `iv` is the induction variable 66 // of the tiled loop. 67 AffineExpr s0, s1, d0; 68 bindDims(builder.getContext(), d0); 69 bindSymbols(builder.getContext(), s0, s1); 70 AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext()); 71 72 for (auto loopRange : llvm::enumerate(loopRanges)) { 73 // No loops if tile size is zero. Set offset and size to the loop 74 // offset and size. 75 if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) { 76 offsets[loopRange.index()] = loopRange.value().offset; 77 sizes[loopRange.index()] = loopRange.value().size; 78 continue; 79 } 80 81 auto loop = builder.create<scf::ForOp>( 82 loc, loopRange.value().offset, loopRange.value().size, 83 tileSizeVals[loopRange.index()], ValueRange{}, 84 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, 85 ValueRange /*iterArgs*/) { 86 Value boundedTileSize = builder.create<AffineMinOp>( 87 bodyLoc, minMap, 88 ValueRange{iv, tileSizeVals[loopRange.index()], 89 loopRange.value().size}); 90 sizes[loopRange.index()] = boundedTileSize; 91 builder.create<scf::YieldOp>(loc); 92 }); 93 offsets[loopRange.index()] = loop.getInductionVar(); 94 loops.push_back(loop); 95 builder.setInsertionPoint(loop.getBody()->getTerminator()); 96 } 97 return loops; 98 } 99 100 scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, 101 scf::SCFTilingOptions options, 102 PatternBenefit benefit) 103 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 104 options(std::move(options)) {} 105 106 scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, 107 MLIRContext *context, 108 scf::SCFTilingOptions options, 109 PatternBenefit benefit) 110 : OpInterfaceRewritePattern<TilingInterface>(context, benefit), 111 options(std::move(options)) {} 112 113 FailureOr<scf::SCFTilingResult> 114 scf::TileUsingSCFForOp::returningMatchAndRewrite( 115 TilingInterface op, PatternRewriter &rewriter) const { 116 OpBuilder::InsertionGuard guard(rewriter); 117 rewriter.setInsertionPointAfter(op); 118 119 if (!options.tileSizeComputationFunction) { 120 return rewriter.notifyMatchFailure( 121 op, "missing tile size computation function"); 122 } 123 124 // 1. Get the range of the loops that are represented by the operation. 125 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); 126 size_t numLoops = iterationDomain.size(); 127 if (numLoops == 0) { 128 return rewriter.notifyMatchFailure( 129 op, "unable to tile op with no iteration domain"); 130 } 131 132 // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" 133 // skips tiling a particular dimension. This convention is significantly 134 // simpler to handle instead of adjusting affine maps to account for missing 135 // dimensions. 136 SmallVector<Value, 4> tileSizeVector = 137 options.tileSizeComputationFunction(rewriter, op); 138 if (tileSizeVector.size() < iterationDomain.size()) { 139 auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); 140 tileSizeVector.append(numLoops - tileSizeVector.size(), zero); 141 } 142 143 scf::SCFTilingResult tilingResult; 144 SmallVector<OpFoldResult> offsets, sizes; 145 { 146 // 3. Materialize an empty loop nest that iterates over the tiles. These 147 // loops for now do not return any values even if the original operation has 148 // results. 149 tilingResult.loops = generateTileLoopNest( 150 rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); 151 152 LLVM_DEBUG({ 153 if (!tilingResult.loops.empty()) { 154 llvm::errs() << "LoopNest shell :\n"; 155 tilingResult.loops.front().dump(); 156 llvm::errs() << "\n"; 157 } 158 }); 159 160 // 4. Generate the tiled implementation within the inner most loop. 161 if (!tilingResult.loops.empty()) 162 rewriter.setInsertionPoint( 163 tilingResult.loops.back().getBody()->getTerminator()); 164 SmallVector<Operation *> tiledImplementation = op.getTiledImplementation( 165 rewriter, op.getDestinationOperands(rewriter), offsets, sizes, true); 166 if (tiledImplementation.size() != 1) { 167 return rewriter.notifyMatchFailure( 168 op, "expected tiled implementation to return a single op"); 169 } 170 tilingResult.tiledOp = tiledImplementation[0]; 171 172 LLVM_DEBUG({ 173 if (!tilingResult.loops.empty()) { 174 llvm::errs() << "After tiled implementation :\n"; 175 tilingResult.loops.front().dump(); 176 llvm::errs() << "\n"; 177 } 178 }); 179 } 180 181 if (op->getNumResults() == 0) { 182 rewriter.eraseOp(op); 183 return tilingResult; 184 } 185 186 // 5. If the original operations has results, modify the loop nest to yield 187 // the replacement values. 188 SmallVector<Value> replacements; 189 if (tilingResult.loops.empty()) { 190 // 5a. If there were no loops, the tiled implementation results are the 191 // replacements. 192 rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); 193 return tilingResult; 194 } 195 196 // 5b. `scf.for` with tensor semantics requires the loop nest to yield the 197 // replacement values using destructive updates. Use the `TilingInterface` 198 // to get the position of the result tiles and use that to generate the 199 // destructive update pattern, i.e., 200 // 201 // ```mlir 202 // scf.for %iv0 = ... { 203 // %0 = tiled_op 204 // } 205 // ``` 206 // 207 // is transformed to 208 // 209 // ```mlir 210 // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { 211 // %0 = tiled_op 212 // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 213 // scf.yield %1 214 // } 215 // ``` 216 NewYieldValueFn yieldValueFn = 217 [&](OpBuilder &b, Location loc, 218 ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { 219 SmallVector<Value> yieldedValues; 220 Attribute one = b.getIndexAttr(1); 221 for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) { 222 SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes; 223 if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, 224 resultTileOffsets, 225 resultTileSizes))) { 226 op.emitOpError("unable to get position of result ") 227 << resultNum << " of the tiled implementation"; 228 return {}; 229 } 230 SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(), 231 one); 232 Value yieldedValue = b.create<tensor::InsertSliceOp>( 233 op->getLoc(), tilingResult.tiledOp->getResult(resultNum), 234 newBBArgs[resultNum], resultTileOffsets, resultTileSizes, 235 resultTileStrides); 236 yieldedValues.push_back(yieldedValue); 237 } 238 return yieldedValues; 239 }; 240 SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields( 241 rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), 242 yieldValueFn); 243 for (auto loop : llvm::enumerate(tilingResult.loops)) { 244 rewriter.eraseOp(loop.value()); 245 tilingResult.loops[loop.index()] = newLoops[loop.index()]; 246 } 247 rewriter.replaceOp(op, tilingResult.loops.front().getResults()); 248 return tilingResult; 249 } 250