1 //===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H 10 #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H 11 12 #include "mlir/Dialect/SCF/IR/SCF.h" 13 #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 14 #include "mlir/IR/PatternMatch.h" 15 #include "mlir/Interfaces/LoopLikeInterface.h" 16 #include "mlir/Interfaces/TilingInterface.h" 17 #include "mlir/Interfaces/ViewLikeInterface.h" 18 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 19 20 #include <deque> 21 22 namespace mlir { 23 class Operation; 24 class RewriterBase; 25 class TilingInterface; 26 } // namespace mlir 27 28 namespace mlir { 29 namespace scf { 30 31 using SCFTileSizeComputationFunction = 32 std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>; 33 34 /// Options to use to control tiling. 35 struct SCFTilingOptions { 36 /// Computation function that returns the tile sizes to use for each loop. 37 /// Returning a tile size of zero implies no tiling for that loop. If the 38 /// size of the returned vector is smaller than the number of loops, the inner 39 /// loops are not tiled. If the size of the returned vector is larger, then 40 /// the vector is truncated to number of loops. 41 SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr; 42 43 SCFTilingOptions & 44 setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) { 45 tileSizeComputationFunction = std::move(fun); 46 return *this; 47 } 48 /// Convenience function to set the `tileSizeComputationFunction` to a 49 /// function that computes tile sizes at the point they are needed. Allows 50 /// proper interaction with folding. 51 SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes); 52 53 /// Computation function that returns the number of threads to use for 54 /// each loop. Returning a num threads of zero implies no tiling for that 55 /// loop. If the size of the returned vector is smaller than the number of 56 /// loops, the inner loops are not tiled. If the size of the returned vector 57 /// is larger, then the vector is truncated to number of loops. Note: This 58 /// option is only supported with loopType set to `LoopType::ForallOp`. If the 59 /// tile size function is not specified while the num threads computation is, 60 /// then the tile size is determined automatically to map at most one tile per 61 /// thread. 62 SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr; 63 64 SCFTilingOptions & 65 setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) { 66 numThreadsComputationFunction = std::move(fun); 67 return *this; 68 } 69 /// Convenience function to set the `numThreadsComputationFunction` to a 70 /// function that computes num threads at the point they are needed. 71 SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads); 72 73 /// The interchange vector to reorder the tiled loops. 74 SmallVector<int64_t> interchangeVector = {}; 75 SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) { 76 interchangeVector = llvm::to_vector(interchange); 77 return *this; 78 } 79 80 /// Specify which loop construct to use for tile and fuse. 81 enum class LoopType { ForOp, ForallOp }; 82 LoopType loopType = LoopType::ForOp; 83 SCFTilingOptions &setLoopType(LoopType type) { 84 loopType = type; 85 return *this; 86 } 87 88 /// Specify how reduction dimensions should be tiled. 89 /// 90 /// Tiling can be thought of as splitting a dimension into 2 and materializing 91 /// the outer dimension as a loop: 92 /// 93 /// op[original] -> op[original / x, x] -> loop[original] { op[x] } 94 /// 95 /// For parallel dimensions, the split can only happen in one way, with both 96 /// dimensions being parallel. For reduction dimensions however, there is a 97 /// choice in how we split the reduction dimension. This enum exposes this 98 /// choice. 99 enum class ReductionTilingStrategy { 100 // [reduction] -> [reduction1, reduction2] 101 // -> loop[reduction1] { [reduction2] } 102 FullReduction, 103 // [reduction] -> [reduction1, parallel2] 104 // -> loop[reduction1] { [parallel2] }; merge[reduction1] 105 PartialReductionOuterReduction, 106 // [reduction] -> [parallel1, reduction2] 107 // -> loop[parallel1] { [reduction2] }; merge[parallel1] 108 PartialReductionOuterParallel 109 }; 110 ReductionTilingStrategy reductionStrategy = 111 ReductionTilingStrategy::FullReduction; 112 SCFTilingOptions & 113 setReductionTilingStrategy(ReductionTilingStrategy strategy) { 114 reductionStrategy = strategy; 115 return *this; 116 } 117 118 /// Specify mapping of loops to devices. This is only respected when the loop 119 /// constructs support such a mapping (like `scf.forall`). Will be ignored 120 /// when using loop constructs that dont support such a mapping (like 121 /// `scf.for`) 122 SmallVector<Attribute> mappingVector = {}; 123 SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) { 124 mappingVector = llvm::to_vector(mapping); 125 return *this; 126 } 127 }; 128 129 /// Transformation information returned after tiling. 130 struct SCFTilingResult { 131 /// Tiled operations that are generated during tiling. The order does not 132 /// matter except the last op. The replacements are expected to be the results 133 /// of the last op. 134 SmallVector<Operation *> tiledOps; 135 /// The initial destination values passed to the tiled operations. 136 SmallVector<Value> initialValues; 137 /// The `scf.for` operations that iterate over the tiles. 138 SmallVector<LoopLikeOpInterface> loops; 139 /// The result generated by the loop nest in tiling, may hold partial results, 140 /// which need to be merged to match the computation of the untiled operation. 141 /// `mergeResult` contains the operations used to perform this merge from 142 /// partial results and the values that can be used as replacements of 143 /// the untiled operation. 144 MergeResult mergeResult; 145 /// Slices generated after tiling that can be used for fusing with the tiled 146 /// producer. 147 SmallVector<Operation *> generatedSlices; 148 }; 149 150 /// Method to tile an op that implements the `TilingInterface` using 151 /// `scf.for` for iterating over the tiles. 152 FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter, 153 TilingInterface op, 154 const SCFTilingOptions &options); 155 156 /// Options used to control tile + fuse. 157 struct SCFTileAndFuseOptions { 158 /// The tiling options used to control the tiling of the consumer. 159 SCFTilingOptions tilingOptions; 160 SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) { 161 tilingOptions = options; 162 return *this; 163 } 164 165 /// Control function to check if a slice needs to be fused or not, 166 /// The control function receives 167 /// 1) the slice along which fusion is to be done, 168 /// 2) the producer value that is to be fused 169 /// 3) a boolean value set to `true` if the fusion is from 170 /// a destination operand. 171 /// The control function returns an `std::optiona<ControlFnResult>`. 172 /// If the return value is `std::nullopt`, that implies no fusion 173 /// is to be performed along that slice. 174 struct ControlFnResult { 175 /// Set to true if the loop nest has to return a replacement value 176 /// for the fused producer. 177 bool yieldProducerReplacement = false; 178 }; 179 using ControlFnTy = std::function<std::optional<ControlFnResult>( 180 tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, 181 bool isDestinationOperand)>; 182 /// The default control function implements greedy fusion without yielding 183 /// a replacement for any of the fused results. 184 ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, 185 bool) -> std::optional<ControlFnResult> { 186 return ControlFnResult{}; 187 }; 188 SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) { 189 fusionControlFn = controlFn; 190 return *this; 191 } 192 193 /// An optional set of rewrite patterns to apply to the results of tiling 194 /// before fusion. This will track deleted and newly inserted 195 /// `tensor.extract_slice` ops and update the worklist. 196 std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt; 197 }; 198 199 /// Fuse the producer of the source of `candidateSliceOp` by computing the 200 /// required slice of the producer in-place. Note that the method 201 /// replaces the uses of `candidateSliceOp` with the tiled and fused producer 202 /// value but does not delete the slice operation. 203 struct SCFFuseProducerOfSliceResult { 204 OpResult origProducer; // Original untiled producer. 205 Value tiledAndFusedProducer; // Tile and fused producer value. 206 SmallVector<Operation *> tiledOps; 207 SmallVector<Operation *> generatedSlices; 208 }; 209 std::optional<SCFFuseProducerOfSliceResult> 210 tileAndFuseProducerOfSlice(RewriterBase &rewriter, 211 tensor::ExtractSliceOp candidateSliceOp, 212 MutableArrayRef<LoopLikeOpInterface> loops); 213 214 /// Reconstruct the fused producer from within the tiled-and-fused code. Based 215 /// on the slice of the producer computed in place it is possible that within 216 /// the loop nest same slice of the producer is computed multiple times. It is 217 /// in general not possible to recompute the value of the fused producer from 218 /// the tiled loop code in such cases. For the cases where no slice of the 219 /// producer is computed in a redundant fashion it is possible to reconstruct 220 /// the value of the original producer from within the tiled loop. It is upto 221 /// the caller to ensure that the producer is not computed redundantly within 222 /// the tiled loop nest. For example, consider 223 /// 224 /// ```mlir 225 /// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> 226 /// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?x?f32> 227 /// ``` 228 /// 229 /// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR 230 /// is, 231 /// 232 /// ```mlir 233 /// %t1_0 = scf.for .... iter_args(%arg0 = ...) { 234 /// %t1_1 = scf.for ... iter_args(%arg1 = %arg0) { 235 /// ... 236 /// %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> 237 /// %t1_3 = linalg.matmul ins(%t1_2, ...) 238 /// %t1_4 = tensor.insert_slice %t1_3 into %arg1 ... 239 /// scf.yield %t1_4 240 /// } 241 /// scf.yield %t1_1 242 /// } 243 /// ``` 244 /// 245 /// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead 246 /// if `%1` were tiled only along the rows, the resultant code would be 247 /// 248 /// ```mlir 249 /// %t2_0 = scf.for .... iter_args(%arg0 = ...) { 250 /// ... 251 /// %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32> 252 /// %t2_2 = linalg.matmul ins(%t2_1, ...) 253 /// %t2_3 = tensor.insert_slice %t2_2 into %arg0 ... 254 /// scf.yield %t2_3 255 /// } 256 /// ``` 257 /// 258 /// Here there is no intersection in the different slices of `%t2_1` computed 259 /// across iterations of the `scf.for`. In such cases, the value of the original 260 /// `%0` can be reconstructed from within the loop body. This is useful in cases 261 /// where `%0` had other uses as well. If not reconstructed from within the loop 262 /// body, uses of `%0` could not be replaced, making it still live and the 263 /// fusion immaterial. 264 /// 265 /// The @param `yieldResultNumber` decides which result would be yield. If not 266 /// given, yield all `opResult` of fused producer. 267 /// 268 /// The method returns the list of new slices added during the process (which 269 /// can be used to fuse along). 270 FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer( 271 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, 272 scf::SCFFuseProducerOfSliceResult fusedProducerInfo, 273 MutableArrayRef<LoopLikeOpInterface> loops, 274 ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{}); 275 276 /// Transformation information returned after tile and fuse. 277 struct SCFTileAndFuseResult { 278 /// List of untiled operations that were fused with the tiled consumer. 279 llvm::SetVector<Operation *> fusedProducers; 280 /// List of tiled and fused operations generated. The first one in this list 281 /// is guaranteed to be the tiled operations generated during tiling of the 282 /// generated operation. 283 llvm::SetVector<Operation *> tiledAndFusedOps; 284 /// The `scf.for` operations that iterate over the tiles. 285 SmallVector<LoopLikeOpInterface> loops; 286 /// The replacement values to use for the tiled and fused operations. 287 llvm::DenseMap<Value, Value> replacements; 288 }; 289 290 /// Method to tile and fuse a sequence of operations, by tiling the consumer 291 /// and fusing its producers. Note that this assumes that it is valid to 292 /// tile+fuse the producer into the innermost tiled loop. Its up to the caller 293 /// to ensure that the tile sizes provided make this fusion valid. 294 /// 295 /// For example, for the following sequence 296 /// 297 /// ```mlir 298 /// %0 = 299 /// %1 = linalg.fill ... outs(%0 : ... ) 300 /// %2 = linalg.matmul ... outs(%1 : ...) ... 301 /// ``` 302 /// 303 /// it is legal to fuse the fill with the matmul only if the matmul is tiled 304 /// along the parallel dimensions and not the reduction dimension, i.e. the tile 305 /// size for the reduction dimension should be 0. The resulting fused 306 /// transformation is 307 /// 308 /// ```mlir 309 /// %1 = scf.for ... iter_args(%arg0 = %0) 310 /// %2 = tensor.extract_slice %arg0 311 /// %3 = linalg.fill .. outs(%2 : ... ) 312 /// %4 = linalg.matmul .. outs(%3 : ...) 313 /// } 314 /// ``` 315 FailureOr<SCFTileAndFuseResult> 316 tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, 317 TilingInterface consumer, 318 const SCFTileAndFuseOptions &options); 319 320 /// Fuse the consumer of the source of `candidateSliceOp` by computing the 321 /// required slice of the consumer in-place. Note that the method 322 /// replaces the uses of `candidateSliceOp` with the tiled and fused consumer 323 /// value but does not delete the slice operation. 324 struct SCFFuseConsumerOfSliceResult { 325 OpOperand *origConsumerOperand; // Original untiled consumer's operand. 326 OpOperand 327 *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand. 328 SmallVector<Operation *> tiledOps; 329 }; 330 FailureOr<scf::SCFFuseConsumerOfSliceResult> 331 tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp); 332 333 /// Method to lower an `op` that implements the `TilingInterface` to 334 /// loops/scalars. 335 FailureOr<SmallVector<scf::ForOp>> 336 lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); 337 338 /// Method to tile a reduction and generate a parallel op within a serial loop. 339 /// Each of the partial reductions are calculated in parallel. Then after the 340 /// loop all the partial reduction are merged into a final reduction. 341 /// For example for the following sequence 342 /// 343 /// ```mlir 344 /// %0 = linalg.generic %in ["parallel", "reduction"] 345 /// : tensor<7x9xf32> -> tensor<7xf32> 346 /// ``` 347 /// 348 /// into: 349 /// 350 /// ```mlir 351 /// %0 = linalg.fill ... : tensor<7x4xf32> 352 /// %1 = scf.for ... iter_args(%arg0 = %0) 353 /// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32> 354 /// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> 355 /// %4 = linalg.generic %2, %3 ["parallel", "parallel"] 356 /// : tensor<7x?xf32> -> tensor<7x?xf32> 357 /// %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32> 358 /// } 359 /// %6 = linalg.generic %1 ["parallel", "reduction"] 360 /// : tensor<7x4xf32> -> tensor<7xf32> 361 /// ``` 362 FailureOr<scf::SCFTilingResult> 363 tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, 364 ArrayRef<OpFoldResult> tileSize); 365 366 } // namespace scf 367 } // namespace mlir 368 369 #endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H 370