xref: /llvm-project/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===- Transforms.h - MemRef Dialect transformations ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// This header declares functions that assist transformations in the MemRef
10 /// dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
15 #define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
16 
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/ADT/STLFunctionalExtras.h"
19 
20 namespace mlir {
21 class OpBuilder;
22 class RewritePatternSet;
23 class RewriterBase;
24 class Value;
25 class ValueRange;
26 
27 namespace arith {
28 class WideIntEmulationConverter;
29 class NarrowTypeEmulationConverter;
30 } // namespace arith
31 
32 namespace memref {
33 class AllocOp;
34 class AllocaOp;
35 class DeallocOp;
36 
37 //===----------------------------------------------------------------------===//
38 // Patterns
39 //===----------------------------------------------------------------------===//
40 
41 /// Collects a set of patterns to rewrite ops within the memref dialect.
42 void populateExpandOpsPatterns(RewritePatternSet &patterns);
43 
44 /// Appends patterns for folding memref aliasing ops into consumer load/store
45 /// ops into `patterns`.
46 void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
47 
48 /// Appends patterns that resolve `memref.dim` operations with values that are
49 /// defined by operations that implement the
50 /// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
51 /// operands.
52 void populateResolveRankedShapedTypeResultDimsPatterns(
53     RewritePatternSet &patterns);
54 
55 /// Appends patterns that resolve `memref.dim` operations with values that are
56 /// defined by operations that implement the `InferShapedTypeOpInterface`, in
57 /// terms of shapes of its input operands.
58 void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
59 
60 /// Appends patterns for expanding memref operations that modify the metadata
61 /// (sizes, offset, strides) of a memref into easier to analyze constructs.
62 void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
63 
64 /// Appends patterns for resolving `memref.extract_strided_metadata` into
65 /// `memref.extract_strided_metadata` of its source.
66 void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns);
67 
68 /// Appends patterns for expanding `memref.realloc` operations.
69 void populateExpandReallocPatterns(RewritePatternSet &patterns,
70                                    bool emitDeallocs = true);
71 
72 /// Appends patterns for emulating wide integer memref operations with ops over
73 /// narrower integer types.
74 void populateMemRefWideIntEmulationPatterns(
75     const arith::WideIntEmulationConverter &typeConverter,
76     RewritePatternSet &patterns);
77 
78 /// Appends type conversions for emulating wide integer memref operations with
79 /// ops over narrower integer types.
80 void populateMemRefWideIntEmulationConversions(
81     arith::WideIntEmulationConverter &typeConverter);
82 
83 /// Appends patterns for emulating memref operations over narrow types with ops
84 /// over wider types.
85 void populateMemRefNarrowTypeEmulationPatterns(
86     const arith::NarrowTypeEmulationConverter &typeConverter,
87     RewritePatternSet &patterns);
88 
89 /// Appends type conversions for emulating memref operations over narrow types
90 /// with ops over wider types.
91 void populateMemRefNarrowTypeEmulationConversions(
92     arith::NarrowTypeEmulationConverter &typeConverter);
93 
94 /// Transformation to do multi-buffering/array expansion to remove dependencies
95 /// on the temporary allocation between consecutive loop iterations.
96 /// It returns the new allocation if the original allocation was multi-buffered
97 /// and returns failure() otherwise.
98 /// When `skipOverrideAnalysis` is set, the pass will apply the transformation
99 /// without checking that the buffer is overwritten at the beginning of each
100 /// iteration. This implies that the user knows that there is no data carried
101 /// across loop iterations. Example:
102 /// ```
103 /// %0 = memref.alloc() : memref<4x128xf32>
104 /// scf.for %iv = %c1 to %c1024 step %c3 {
105 ///   memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
106 ///   "some_use"(%0) : (memref<4x128xf32>) -> ()
107 /// }
108 /// ```
109 /// into:
110 /// ```
111 /// %0 = memref.alloc() : memref<5x4x128xf32>
112 /// scf.for %iv = %c1 to %c1024 step %c3 {
113 ///   %s = arith.subi %iv, %c1 : index
114 ///   %d = arith.divsi %s, %c3 : index
115 ///   %i = arith.remsi %d, %c5 : index
116 ///   %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
117 ///     memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
118 ///   memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
119 ///   "some_use"(%sv) : (memref<4x128xf32, strided<...>) -> ()
120 /// }
121 /// ```
122 FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
123                                        memref::AllocOp allocOp,
124                                        unsigned multiplier,
125                                        bool skipOverrideAnalysis = false);
126 /// Call into `multiBuffer` with a locally constructed IRRewriter.
127 FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
128                                        unsigned multiplier,
129                                        bool skipOverrideAnalysis = false);
130 
131 /// Appends patterns for extracting address computations from the instructions
132 /// with memory accesses such that these memory accesses use only a base
133 /// pointer.
134 ///
135 /// For instance,
136 /// ```mlir
137 /// memref.load %base[%off0, ...]
138 /// ```
139 ///
140 /// Will be rewritten in:
141 /// ```mlir
142 /// %new_base = memref.subview %base[%off0,...][1,...][1,...]
143 /// memref.load %new_base[%c0,...]
144 /// ```
145 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
146 
147 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
148 /// given independencies. If the op is already independent of all
149 /// independencies, the same AllocaOp result is returned.
150 ///
151 /// Failure indicates that no suitable upper bound for the dynamic sizes could be
152 /// found.
153 FailureOr<Value> buildIndependentOp(OpBuilder &b, AllocaOp allocaOp,
154                                     ValueRange independencies);
155 
156 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
157 /// given independencies. If the op is already independent of all
158 /// independencies, the same AllocaOp result is returned.
159 ///
160 /// The original AllocaOp is replaced with the new one, wrapped in a SubviewOp.
161 /// The result type of the replacement is different from the original allocation
162 /// type: it has the same shape, but a different layout map. This function
163 /// updates all users that do not have a memref result or memref region block
164 /// argument, and some frequently used memref dialect ops (such as
165 /// memref.subview). It does not update other uses such as the init_arg of an
166 /// scf.for op. Such uses are wrapped in unrealized_conversion_cast.
167 ///
168 /// Failure indicates that no suitable upper bound for the dynamic sizes could be
169 /// found.
170 ///
171 /// Example (make independent of %iv):
172 /// ```
173 /// scf.for %iv = %c0 to %sz step %c1 {
174 ///   %0 = memref.alloca(%iv) : memref<?xf32>
175 ///   %1 = memref.subview %0[0][5][1] : ...
176 ///   linalg.generic outs(%1 : ...) ...
177 ///   %2 = scf.for ... iter_arg(%arg0 = %0) ...
178 ///   ...
179 /// }
180 /// ```
181 ///
182 /// The above IR is rewritten to:
183 ///
184 /// ```
185 /// scf.for %iv = %c0 to %sz step %c1 {
186 ///   %0 = memref.alloca(%sz - 1) : memref<?xf32>
187 ///   %0_subview = memref.subview %0[0][%iv][1]
188 ///       : memref<?xf32> to memref<?xf32, #map>
189 ///   %1 = memref.subview %0_subview[0][5][1] : ...
190 ///   linalg.generic outs(%1 : ...) ...
191 ///   %cast = unrealized_conversion_cast %0_subview
192 ///       : memref<?xf32, #map> to memref<?xf32>
193 ///   %2 = scf.for ... iter_arg(%arg0 = %cast) ...
194 ///  ...
195 /// }
196 /// ```
197 FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
198                                           memref::AllocaOp allocaOp,
199                                           ValueRange independencies);
200 
201 /// Replaces the given `alloc` with the corresponding `alloca` and returns it if
202 /// the following conditions are met:
203 ///   - the corresponding dealloc is available in the same block as the alloc;
204 ///   - the filter, if provided, succeeds on the alloc/dealloc pair.
205 /// Otherwise returns nullptr and leaves the IR unchanged.
206 memref::AllocaOp allocToAlloca(
207     RewriterBase &rewriter, memref::AllocOp alloc,
208     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
209 
210 } // namespace memref
211 } // namespace mlir
212 
213 #endif
214