//===- Transforms.h - Tensor Transformation Patterns ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

namespace mlir {

struct TilingResult;

namespace tensor {

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

/// Method to swap a `tensor.extract_slice` with its producer when the
/// producer implements the `TilingInterface`. The method itself does not
/// provide a mechanism to control where the application happens; when driven
/// through the transform dialect, that control is expressed in the transform
/// dialect. Other use cases can wrap this method in a pattern that adds the
/// necessary controls.
FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
    OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
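//
// Usage sketch (illustrative only, not a declaration in this header): from
// within a rewrite, swap a slice of a producer that implements
// TilingInterface. A PatternRewriter `rewriter` and the `sliceOp` are assumed
// to be given, and the slice source is assumed to be an OpResult of such a
// producer.
//
//   OpResult producer = cast<OpResult>(sliceOp.getSource());
//   FailureOr<TilingResult> tiled =
//       replaceExtractSliceWithTiledProducer(rewriter, sliceOp, producer);
//   if (failed(tiled))
//     return failure();
//   // Replace the slice with the tiled values computed by the producer.
//   rewriter.replaceOp(sliceOp, tiled->tiledValues);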

/// Method to swap a `tensor.insert_slice` with its consumer when the
/// consumer implements the `TilingInterface`.
FailureOr<TilingResult>
replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
                                    OffsetSizeAndStrideOpInterface sliceOp,
                                    OpOperand &consumerOp);

//===----------------------------------------------------------------------===//
// Populate functions.
//===----------------------------------------------------------------------===//

/// Appends to `patterns` the patterns that fold tensor subset ops into
/// consumer load/store ops. (This includes patterns for folding tensor subset
/// ops into vector transfer ops.)
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns);
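//
// Usage sketch (illustrative only): the populate functions in this section
// follow the usual MLIR idiom of filling a RewritePatternSet and handing it
// to a rewrite driver, e.g. the greedy driver from
// "mlir/Transforms/GreedyPatternRewriteDriver.h" (assuming `op` is the root
// operation to rewrite):
//
//   RewritePatternSet patterns(op->getContext());
//   populateFoldTensorSubsetOpPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     return failure();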

/// Appends patterns for folding tensor subset ops into vector transfer ops.
void populateFoldTensorSubsetIntoVectorTransferPatterns(
    RewritePatternSet &patterns);

/// Collects patterns that merge consecutive tensor.insert_slice/extract_slice
/// ops into a single op. These patterns live in a separate entry point because
/// bufferization is sensitive to IR structure, in particular to the
/// tensor.extract_slice and tensor.insert_slice ops that create the slices.
void populateMergeConsecutiveInsertExtractSlicePatterns(
    RewritePatternSet &patterns);
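//
// For illustration (hand-written IR, not taken from the implementation), two
// consecutive extract_slice ops such as
//
//   %0 = tensor.extract_slice %t[1, 2] [8, 8] [1, 1]
//       : tensor<16x16xf32> to tensor<8x8xf32>
//   %1 = tensor.extract_slice %0[2, 2] [4, 4] [1, 1]
//       : tensor<8x8xf32> to tensor<4x4xf32>
//
// can be merged into a single slice with composed offsets:
//
//   %1 = tensor.extract_slice %t[3, 4] [4, 4] [1, 1]
//       : tensor<16x16xf32> to tensor<4x4xf32>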

/// Populates `patterns` with patterns that drop redundant tensor.insert_slice
/// rank expansions.
void populateDropRedundantInsertSliceRankExpansionPatterns(
    RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold `tensor.expand_shape` and
/// `tensor.collapse_shape` into other ops.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that bubble up `tensor.expand_shape`
/// through `tensor.collapse_shape` ops.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold tensor.empty with its
/// consumers.
///
/// If `foldSingleUseOnly` is set to `true`, only tensor.empty ops with a
/// single use are folded.
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                                     bool foldSingleUseOnly = false);
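//
// For illustration (hand-written IR), one such folding rewrites an
// extract_slice of a tensor.empty
//
//   %0 = tensor.empty() : tensor<10x10xf32>
//   %1 = tensor.extract_slice %0[0, 0] [5, 5] [1, 1]
//       : tensor<10x10xf32> to tensor<5x5xf32>
//
// into a smaller tensor.empty of the sliced shape:
//
//   %1 = tensor.empty() : tensor<5x5xf32>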

/// Populates `patterns` with patterns that decompose `tensor.concat` into a
/// `tensor.empty` of the concatenated size, followed by a chain of
/// `tensor.insert_slice` operations on the inputs. This is intended to be used
/// as a fallback tensor -> tensor lowering that decomposes concat such that it
/// can be bufferized into a sequence of copies.
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
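//
// For illustration (hand-written IR), a concatenation along dimension 0
//
//   %0 = tensor.concat dim(0) %a, %b
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//
// decomposes into roughly:
//
//   %empty = tensor.empty() : tensor<5x4xf32>
//   %s0 = tensor.insert_slice %a into %empty[0, 0] [2, 4] [1, 1]
//       : tensor<2x4xf32> into tensor<5x4xf32>
//   %0 = tensor.insert_slice %b into %s0[2, 0] [3, 4] [1, 1]
//       : tensor<3x4xf32> into tensor<5x4xf32>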

/// Populates `patterns` with patterns that simplify `tensor.pack` and
/// `tensor.unpack` operations.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that fold operations like `tensor.pad`
/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
/// respectively.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);

using ControlFoldFn = std::function<bool(OpOperand *)>;

/// Populates `patterns` with patterns that replace tensor ops (such as
/// tensor.generate) with constants when possible.
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
                                       const ControlFoldFn &controlFn);
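//
// Usage sketch (illustrative only): the control function decides, per use,
// whether the rewrite may fire. For example, to only rewrite uses owned by
// ops carrying a hypothetical "rewrite_as_constant" attribute (assuming an
// MLIRContext `ctx`):
//
//   ControlFoldFn controlFn = [](OpOperand *operand) {
//     return operand->getOwner()->hasAttr("rewrite_as_constant");
//   };
//   RewritePatternSet patterns(ctx);
//   populateRewriteAsConstantPatterns(patterns, controlFn);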

//===----------------------------------------------------------------------===//
// Transform helpers
//===----------------------------------------------------------------------===//

/// Build a new tensor::PadOp with low/high padding that is independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same PadOp result is returned.
///
/// Failure indicates that no suitable upper bound for the low/high padding
/// could be found.
///
/// Example:
/// scf.for %iv = %lb to %ub step %step {
///   %high = affine.apply affine_map<(d0)[s0] -> (s0 - d0)> (%iv)[%ub]
///   %p = tensor.pad %t low[5] high[%high] ...
///   ...
/// }
///
/// The function builds IR such as:
/// %high_new = affine.apply affine_map<()[s0, s1] -> (-s0 + s1)> ()[%lb, %ub]
/// %p_hoistable = tensor.pad %t low[5] high[%high_new]
/// %dim = tensor.dim %t, %c0
/// %size = affine.apply affine_map<(d0)[s0, s1] -> (-d0 + s0 + s1 + 5)>
///     (%iv)[%ub, %dim]
/// %slice = tensor.extract_slice %p_hoistable [0] [%size] [1]
///
/// The slice is returned.
FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
                                    ValueRange independencies);
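//
// Usage sketch (illustrative only, assuming `padOp` is nested inside an
// scf::ForOp `forOp` and the goal is to make its shape independent of the
// loop induction variable):
//
//   OpBuilder b(padOp);
//   SmallVector<Value> independencies = {forOp.getInductionVar()};
//   FailureOr<Value> replacement =
//       buildIndependentOp(b, padOp, independencies);
//   if (succeeded(replacement))
//     padOp.getResult().replaceAllUsesWith(*replacement);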

/// Build a new tensor::EmptyOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same EmptyOp result is returned.
///
/// Failure indicates that no suitable upper bound for the dynamic sizes could
/// be found.
FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::EmptyOp emptyOp,
                                    ValueRange independencies);

} // namespace tensor
} // namespace mlir

#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H