xref: /llvm-project/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (revision 4b56345895729fda3bc3c094bc3f237ba3a49686)
1 //===- TileUsingInterface.h - Tiling ops using TilingInterface --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
10 #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
11 
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"

#include <deque>
#include <functional>
#include <optional>
#include <utility>
21 
22 namespace mlir {
23 class Operation;
24 class RewriterBase;
25 class TilingInterface;
26 } // namespace mlir
27 
28 namespace mlir {
29 namespace scf {
30 
/// Callback used to compute the tile sizes (or the number of threads) to use
/// for each loop of an operation, at the point where they are needed.
using SCFTileSizeComputationFunction =
    std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;

/// Options to use to control tiling.
struct SCFTilingOptions {
  /// Computation function that returns the tile sizes to use for each loop.
  /// Returning a tile size of zero implies no tiling for that loop. If the
  /// size of the returned vector is smaller than the number of loops, the inner
  /// loops are not tiled. If the size of the returned vector is larger, then
  /// the vector is truncated to the number of loops.
  SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;

  /// Set the callback that computes the tile sizes.
  SCFTilingOptions &
  setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
    tileSizeComputationFunction = std::move(fun);
    return *this;
  }
  /// Convenience function to set the `tileSizeComputationFunction` to a
  /// function that computes tile sizes at the point they are needed. Allows
  /// proper interaction with folding.
  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);

  /// Computation function that returns the number of threads to use for
  /// each loop. Returning a number of threads of zero implies no tiling for
  /// that loop. If the size of the returned vector is smaller than the number
  /// of loops, the inner loops are not tiled. If the size of the returned
  /// vector is larger, then the vector is truncated to the number of loops.
  /// Note: This option is only supported with loopType set to
  /// `LoopType::ForallOp`. If the tile size function is not specified while
  /// the num threads computation is, then the tile size is determined
  /// automatically to map at most one tile per thread.
  SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;

  /// Set the callback that computes the number of threads per loop.
  SCFTilingOptions &
  setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
    numThreadsComputationFunction = std::move(fun);
    return *this;
  }
  /// Convenience function to set the `numThreadsComputationFunction` to a
  /// function that computes num threads at the point they are needed.
  SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);

  /// The interchange vector to reorder the tiled loops.
  SmallVector<int64_t> interchangeVector = {};
  /// Set the interchange vector; `interchange` is copied into the options.
  SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
    interchangeVector = llvm::to_vector(interchange);
    return *this;
  }

  /// Specify which loop construct to use for tile and fuse.
  enum class LoopType { ForOp, ForallOp };
  LoopType loopType = LoopType::ForOp;
  /// Set the loop construct (`scf.for` or `scf.forall`) to generate.
  SCFTilingOptions &setLoopType(LoopType type) {
    loopType = type;
    return *this;
  }

  /// Specify how reduction dimensions should be tiled.
  ///
  /// Tiling can be thought of as splitting a dimension into 2 and materializing
  /// the outer dimension as a loop:
  ///
  /// op[original] -> op[original / x, x] -> loop[original] { op[x] }
  ///
  /// For parallel dimensions, the split can only happen in one way, with both
  /// dimensions being parallel. For reduction dimensions however, there is a
  /// choice in how we split the reduction dimension. This enum exposes this
  /// choice.
  enum class ReductionTilingStrategy {
    // [reduction] -> [reduction1, reduction2]
    // -> loop[reduction1] { [reduction2] }
    FullReduction,
    // [reduction] -> [reduction1, parallel2]
    // -> loop[reduction1] { [parallel2] }; merge[reduction1]
    PartialReductionOuterReduction,
    // [reduction] -> [parallel1, reduction2]
    // -> loop[parallel1] { [reduction2] }; merge[parallel1]
    PartialReductionOuterParallel
  };
  ReductionTilingStrategy reductionStrategy =
      ReductionTilingStrategy::FullReduction;
  /// Set the strategy used when tiling reduction dimensions.
  SCFTilingOptions &
  setReductionTilingStrategy(ReductionTilingStrategy strategy) {
    reductionStrategy = strategy;
    return *this;
  }

  /// Specify mapping of loops to devices. This is only respected when the loop
  /// constructs support such a mapping (like `scf.forall`). Will be ignored
  /// when using loop constructs that don't support such a mapping (like
  /// `scf.for`).
  SmallVector<Attribute> mappingVector = {};
  /// Set the mapping attributes to attach to the generated loops.
  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
    mappingVector = llvm::to_vector(mapping);
    return *this;
  }
};
128 
/// Transformation information returned after tiling.
struct SCFTilingResult {
  /// Tiled operations that are generated during tiling. The order does not
  /// matter except for the last op. The replacements are expected to be the
  /// results of the last op.
  SmallVector<Operation *> tiledOps;
  /// The initial destination values passed to the tiled operations.
  SmallVector<Value> initialValues;
  /// The `scf.for` operations that iterate over the tiles.
  SmallVector<LoopLikeOpInterface> loops;
  /// The result generated by the loop nest in tiling. It may hold partial
  /// results, which need to be merged to match the computation of the untiled
  /// operation. `mergeResult` contains the operations used to perform this
  /// merge from partial results and the values that can be used as
  /// replacements of the untiled operation.
  MergeResult mergeResult;
  /// Slices generated after tiling that can be used for fusing with the tiled
  /// producer.
  SmallVector<Operation *> generatedSlices;
};
149 
/// Method to tile an op that implements the `TilingInterface` using
/// `scf.for` for iterating over the tiles. Tile sizes, the loop construct,
/// interchange and the reduction tiling strategy are controlled through
/// `options`. Returns failure when tiling cannot be performed.
FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
                                        TilingInterface op,
                                        const SCFTilingOptions &options);
155 
156 /// Options used to control tile + fuse.
157 struct SCFTileAndFuseOptions {
158   /// The tiling options used to control the tiling of the consumer.
159   SCFTilingOptions tilingOptions;
160   SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) {
161     tilingOptions = options;
162     return *this;
163   }
164 
165   /// Control function to check if a slice needs to be fused or not,
166   /// The control function receives
167   /// 1) the slice along which fusion is to be done,
168   /// 2) the producer value that is to be fused
169   /// 3) a boolean value set to `true` if the fusion is from
170   ///    a destination operand.
171   /// The control function returns an `std::optiona<ControlFnResult>`.
172   /// If the return value is `std::nullopt`, that implies no fusion
173   /// is to be performed along that slice.
174   struct ControlFnResult {
175     /// Set to true if the loop nest has to return a replacement value
176     /// for the fused producer.
177     bool yieldProducerReplacement = false;
178   };
179   using ControlFnTy = std::function<std::optional<ControlFnResult>(
180       tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
181       bool isDestinationOperand)>;
182   /// The default control function implements greedy fusion without yielding
183   /// a replacement for any of the fused results.
184   ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult,
185                                    bool) -> std::optional<ControlFnResult> {
186     return ControlFnResult{};
187   };
188   SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) {
189     fusionControlFn = controlFn;
190     return *this;
191   }
192 
193   /// An optional set of rewrite patterns to apply to the results of tiling
194   /// before fusion. This will track deleted and newly inserted
195   /// `tensor.extract_slice` ops and update the worklist.
196   std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
197 };
198 
/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place.  Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
/// value but does not delete the slice operation.
struct SCFFuseProducerOfSliceResult {
  OpResult origProducer;       // Original untiled producer.
  Value tiledAndFusedProducer; // Tile and fused producer value.
  /// Tiled operations created while fusing the producer.
  SmallVector<Operation *> tiledOps;
  /// Slices generated during fusion; candidates for further fusion
  /// (mirrors `SCFTilingResult::generatedSlices`).
  SmallVector<Operation *> generatedSlices;
};
/// Returns `std::nullopt` when the producer cannot be fused along
/// `candidateSliceOp`.
std::optional<SCFFuseProducerOfSliceResult>
tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                           tensor::ExtractSliceOp candidateSliceOp,
                           MutableArrayRef<LoopLikeOpInterface> loops);
213 
214 /// Reconstruct the fused producer from within the tiled-and-fused code. Based
215 /// on the slice of the producer computed in place it is possible that within
216 /// the loop nest same slice of the producer is computed multiple times. It is
217 /// in general not possible to recompute the value of the fused producer from
218 /// the tiled loop code in such cases. For the cases where no slice of the
219 /// producer is computed in a redundant fashion it is possible to reconstruct
/// the value of the original producer from within the tiled loop. It is up to
221 /// the caller to ensure that the producer is not computed redundantly within
222 /// the tiled loop nest. For example, consider
223 ///
224 /// ```mlir
225 /// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
226 /// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?x?f32>
227 /// ```
228 ///
229 /// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR
230 /// is,
231 ///
232 /// ```mlir
233 /// %t1_0 = scf.for .... iter_args(%arg0 = ...) {
234 ///   %t1_1 = scf.for ... iter_args(%arg1 = %arg0) {
235 ///     ...
236 ///     %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
237 ///     %t1_3 = linalg.matmul ins(%t1_2, ...)
238 ///     %t1_4 = tensor.insert_slice %t1_3 into %arg1 ...
239 ///     scf.yield %t1_4
240 ///   }
241 ///   scf.yield %t1_1
242 /// }
243 /// ```
244 ///
245 /// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead
246 /// if `%1` were tiled only along the rows, the resultant code would be
247 ///
248 /// ```mlir
249 /// %t2_0 = scf.for .... iter_args(%arg0 = ...) {
250 ///   ...
251 ///   %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
252 ///   %t2_2 = linalg.matmul ins(%t2_1, ...)
253 ///   %t2_3 = tensor.insert_slice %t2_2 into %arg0 ...
254 ///   scf.yield %t2_3
255 /// }
256 /// ```
257 ///
258 /// Here there is no intersection in the different slices of `%t2_1` computed
259 /// across iterations of the `scf.for`. In such cases, the value of the original
260 /// `%0` can be reconstructed from within the loop body. This is useful in cases
261 /// where `%0` had other uses as well. If not reconstructed from within the loop
262 /// body, uses of `%0` could not be replaced, making it still live and the
263 /// fusion immaterial.
264 ///
/// The `yieldResultNumber` parameter decides which results are yielded. If not
/// given, all `opResult`s of the fused producer are yielded.
267 ///
268 /// The method returns the list of new slices added during the process (which
269 /// can be used to fuse along).
270 FailureOr<SmallVector<Operation *>> yieldReplacementForFusedProducer(
271     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
272     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
273     MutableArrayRef<LoopLikeOpInterface> loops,
274     ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
275 
/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
  /// List of untiled operations that were fused with the tiled consumer.
  llvm::SetVector<Operation *> fusedProducers;
  /// List of tiled and fused operations generated. The first one in this list
  /// is guaranteed to be the tiled operation generated during tiling of the
  /// consumer operation the transformation was applied to.
  llvm::SetVector<Operation *> tiledAndFusedOps;
  /// The `scf.for` operations that iterate over the tiles.
  SmallVector<LoopLikeOpInterface> loops;
  /// The replacement values to use for the tiled and fused operations.
  llvm::DenseMap<Value, Value> replacements;
};
289 
290 /// Method to tile and fuse a sequence of operations, by tiling the consumer
291 /// and fusing its producers. Note that this assumes that it is valid to
/// tile+fuse the producer into the innermost tiled loop. It is up to the caller
293 /// to ensure that the tile sizes provided make this fusion valid.
294 ///
295 /// For example, for the following sequence
296 ///
297 /// ```mlir
298 /// %0 =
299 /// %1 = linalg.fill ... outs(%0 : ... )
300 /// %2 = linalg.matmul ... outs(%1 : ...) ...
301 /// ```
302 ///
303 /// it is legal to fuse the fill with the matmul only if the matmul is tiled
304 /// along the parallel dimensions and not the reduction dimension, i.e. the tile
305 /// size for the reduction dimension should be 0. The resulting fused
306 /// transformation is
307 ///
308 /// ```mlir
309 /// %1 = scf.for ... iter_args(%arg0 = %0)
310 ///   %2 = tensor.extract_slice %arg0
311 ///   %3 = linalg.fill .. outs(%2 : ... )
312 ///   %4 = linalg.matmul .. outs(%3 : ...)
313 /// }
314 /// ```
315 FailureOr<SCFTileAndFuseResult>
316 tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
317                                      TilingInterface consumer,
318                                      const SCFTileAndFuseOptions &options);
319 
/// Fuse the consumer of the source of `candidateSliceOp` by computing the
/// required slice of the consumer in-place.  Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
/// value but does not delete the slice operation.
struct SCFFuseConsumerOfSliceResult {
  OpOperand *origConsumerOperand; // Original untiled consumer's operand.
  OpOperand
      *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
  /// Tiled operations created while fusing the consumer.
  SmallVector<Operation *> tiledOps;
};
/// Returns failure when the consumer cannot be fused along `candidateSliceOp`.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
332 
333 /// Method to lower an `op` that implements the `TilingInterface` to
334 /// loops/scalars.
335 FailureOr<SmallVector<scf::ForOp>>
336 lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
337 
338 /// Method to tile a reduction and generate a parallel op within a serial loop.
339 /// Each of the partial reductions are calculated in parallel. Then after the
/// loop all the partial reductions are merged into a final reduction.
341 /// For example for the following sequence
342 ///
343 /// ```mlir
344 /// %0 = linalg.generic %in ["parallel", "reduction"]
345 ///   : tensor<7x9xf32> -> tensor<7xf32>
346 /// ```
347 ///
348 /// into:
349 ///
350 /// ```mlir
351 /// %0 = linalg.fill ... : tensor<7x4xf32>
352 /// %1 = scf.for ... iter_args(%arg0 = %0)
353 ///   %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> -> tensor<7x?xf32>
354 ///   %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32>
355 ///   %4 = linalg.generic %2, %3 ["parallel", "parallel"]
356 ///     : tensor<7x?xf32> -> tensor<7x?xf32>
357 ///   %5 = tensor.insert_slice %3, %0[0, 0] : tensor<7x4xf32>
358 /// }
359 /// %6 = linalg.generic %1 ["parallel", "reduction"]
360 ///   : tensor<7x4xf32> -> tensor<7xf32>
361 /// ```
362 FailureOr<scf::SCFTilingResult>
363 tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
364                       ArrayRef<OpFoldResult> tileSize);
365 
366 } // namespace scf
367 } // namespace mlir
368 
369 #endif // MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
370