//===- Transforms.h - MemRef Dialect transformations ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This header declares functions that assist transformations in the MemRef
/// dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H

#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLFunctionalExtras.h"

namespace mlir {
class OpBuilder;
class RewritePatternSet;
class RewriterBase;
class Value;
class ValueRange;

namespace arith {
class WideIntEmulationConverter;
class NarrowTypeEmulationConverter;
} // namespace arith

namespace memref {
class AllocOp;
class AllocaOp;
class DeallocOp;

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

/// Collects a set of patterns to rewrite ops within the memref dialect.
void populateExpandOpsPatterns(RewritePatternSet &patterns);

/// Appends patterns for folding memref aliasing ops into consumer load/store
/// ops into `patterns`.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);

/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the
/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input
/// operands.
void populateResolveRankedShapedTypeResultDimsPatterns(
    RewritePatternSet &patterns);

/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the `InferShapedTypeOpInterface`, in
/// terms of shapes of its input operands.
void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);

/// Appends patterns for expanding memref operations that modify the metadata
/// (sizes, offset, strides) of a memref into easier-to-analyze constructs.
void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);

/// Appends patterns for resolving `memref.extract_strided_metadata` into
/// `memref.extract_strided_metadata` of its source.
void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns);

/// Appends patterns for expanding `memref.realloc` operations.
void populateExpandReallocPatterns(RewritePatternSet &patterns,
                                   bool emitDeallocs = true);

/// Appends patterns for emulating wide integer memref operations with ops over
/// narrower integer types.
void populateMemRefWideIntEmulationPatterns(
    const arith::WideIntEmulationConverter &typeConverter,
    RewritePatternSet &patterns);

/// Appends type conversions for emulating wide integer memref operations with
/// ops over narrower integer types.
void populateMemRefWideIntEmulationConversions(
    arith::WideIntEmulationConverter &typeConverter);
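
// Example: the pattern-population entry points above are typically driven from
// a pass via the greedy rewrite driver. A minimal sketch (the pass boilerplate
// and the applyPatternsAndFoldGreedily call are illustrative and not declared
// in this header); the emulation variants are instead intended to be paired
// with their `populate*Conversions` counterparts and a type converter:
//
//   RewritePatternSet patterns(&getContext());
//   memref::populateExpandStridedMetadataPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();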

/// Appends patterns for emulating memref operations over narrow types with ops
/// over wider types.
void populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns);

/// Appends type conversions for emulating memref operations over narrow types
/// with ops over wider types.
void populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter);

/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
/// It returns the new allocation if the original allocation was multi-buffered
/// and returns failure() otherwise.
/// When `skipOverrideAnalysis` is set, the transformation is applied without
/// checking that the buffer is overwritten at the beginning of each iteration.
/// This implies that the user knows that no data is carried across loop
/// iterations. Example:
/// ```
/// %0 = memref.alloc() : memref<4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
///   memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
///   "some_use"(%0) : (memref<4x128xf32>) -> ()
/// }
/// ```
/// into:
/// ```
/// %0 = memref.alloc() : memref<5x4x128xf32>
/// scf.for %iv = %c1 to %c1024 step %c3 {
///   %s = arith.subi %iv, %c1 : index
///   %d = arith.divsi %s, %c3 : index
///   %i = arith.remsi %d, %c5 : index
///   %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
///     memref<5x4x128xf32> to memref<4x128xf32, strided<[128, 1], offset: ?>>
///   memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, strided<...>>
///   "some_use"(%sv) : (memref<4x128xf32, strided<...>>) -> ()
/// }
/// ```
FailureOr<memref::AllocOp> multiBuffer(RewriterBase &rewriter,
                                       memref::AllocOp allocOp,
                                       unsigned multiplier,
                                       bool skipOverrideAnalysis = false);
/// Calls into `multiBuffer` with a locally constructed IRRewriter.
FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
                                       unsigned multiplier,
                                       bool skipOverrideAnalysis = false);

/// Appends patterns for extracting address computations from instructions with
/// memory accesses so that these memory accesses use only a base pointer.
///
/// For instance,
/// ```mlir
/// memref.load %base[%off0, ...]
/// ```
///
/// is rewritten into:
/// ```mlir
/// %new_base = memref.subview %base[%off0,...][1,...][1,...]
/// memref.load %new_base[%c0,...]
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);

/// Builds a new memref::AllocaOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same AllocaOp result is returned.
///
/// Failure indicates that no suitable upper bound for the dynamic sizes could
/// be found.
FailureOr<Value> buildIndependentOp(OpBuilder &b, AllocaOp allocaOp,
                                    ValueRange independencies);
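
// Example: making an alloca's dynamic size independent of a loop induction
// variable. A minimal sketch (`allocaOp` and `loopIv` are illustrative values
// taken from the surrounding IR):
//
//   SmallVector<Value> independencies = {loopIv};
//   OpBuilder b(allocaOp);
//   // On success, `replacement` is an alloca (possibly the original one)
//   // whose dynamic sizes no longer depend on `loopIv`.
//   FailureOr<Value> replacement =
//       memref::buildIndependentOp(b, allocaOp, independencies);
//   if (failed(replacement))
//     return failure();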

/// Builds a new memref::AllocaOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same AllocaOp result is returned.
///
/// The original AllocaOp is replaced with the new one, wrapped in a SubviewOp.
/// The result type of the replacement is different from the original
/// allocation type: it has the same shape, but a different layout map. This
/// function updates all users that do not have a memref result or memref
/// region block argument, and some frequently used memref dialect ops (such
/// as memref.subview). It does not update other uses such as the init_arg of
/// an scf.for op. Such uses are wrapped in unrealized_conversion_cast.
///
/// Failure indicates that no suitable upper bound for the dynamic sizes could
/// be found.
///
/// Example (make independent of %iv):
/// ```
/// scf.for %iv = %c0 to %sz step %c1 {
///   %0 = memref.alloca(%iv) : memref<?xf32>
///   %1 = memref.subview %0[0][5][1] : ...
///   linalg.generic outs(%1 : ...) ...
///   %2 = scf.for ... iter_arg(%arg0 = %0) ...
///   ...
/// }
/// ```
///
/// The above IR is rewritten to:
///
/// ```
/// scf.for %iv = %c0 to %sz step %c1 {
///   %0 = memref.alloca(%sz - 1) : memref<?xf32>
///   %0_subview = memref.subview %0[0][%iv][1]
///       : memref<?xf32> to memref<?xf32, #map>
///   %1 = memref.subview %0_subview[0][5][1] : ...
///   linalg.generic outs(%1 : ...) ...
///   %cast = unrealized_conversion_cast %0_subview
///       : memref<?xf32, #map> to memref<?xf32>
///   %2 = scf.for ... iter_arg(%arg0 = %cast) ...
///   ...
/// }
/// ```
FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
                                          memref::AllocaOp allocaOp,
                                          ValueRange independencies);

/// Replaces the given `alloc` with the corresponding `alloca` and returns it
/// if the following conditions are met:
///   - the corresponding dealloc is available in the same block as the alloc;
///   - the filter, if provided, succeeds on the alloc/dealloc pair.
/// Otherwise returns nullptr and leaves the IR unchanged.
memref::AllocaOp allocToAlloca(
    RewriterBase &rewriter, memref::AllocOp alloc,
    function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);

} // namespace memref
} // namespace mlir

#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H