xref: /llvm-project/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td (revision 7982bc340bf7e627ddc552e941d8b0493ab9248d)
1//===-- Passes.td - MemRef transformation definition file --*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
10#define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
11
12include "mlir/Pass/PassBase.td"
13
// Pass "memref-expand": legalizes memref operations so that they can be
// converted to LLVM. The pass object is created by the C++ factory named in
// `constructor`.
def ExpandOps : Pass<"memref-expand"> {
  // One-line summary; no trailing period, matching the other passes in this
  // file.
  let summary = "Legalize memref operations to be convertible to LLVM";
  let constructor = "mlir::memref::createExpandOpsPass()";
}
18
// Pass "fold-memref-alias-ops": folds loads/stores that go through a memref
// aliasing op into loads/stores on the original memref.
def FoldMemRefAliasOps : Pass<"fold-memref-alias-ops"> {
  // Dialects declared as dependencies of this pass.
  let dependentDialects = [
    "affine::AffineDialect",
    "memref::MemRefDialect",
    "vector::VectorDialect"
  ];
  // C++ factory expression used to instantiate the pass.
  let constructor = "mlir::memref::createFoldMemRefAliasOpsPass()";
  let summary = "Fold memref alias ops into consumer load/store ops";
  let description = [{
    The pass folds loading/storing from/to memref aliasing ops to loading/storing
    from/to the original memref.
  }];
}
30
// Pass "memref-emulate-wide-int": emulates memref operations on integer types
// that are too wide by using narrower supported integer types.
def MemRefEmulateWideInt : Pass<"memref-emulate-wide-int"> {
  // Dialects declared as dependencies of this pass.
  let dependentDialects = [
    "vector::VectorDialect"
  ];
  // Command-line options: C++ member name, flag name, C++ type, default value
  // and help text.
  let options = [
    Option<
      "widestIntSupported", "widest-int-supported", "unsigned",
      /*default=*/"32",
      "Widest integer type supported by the target"
    >,
  ];
  let summary = "Emulate 2*N-bit integer operations using N-bit operations";
  let description = [{
    Emulate memref integer operations that use too wide integer types with
    equivalent operations on supported narrow integer types. This is done by
    splitting original integer values into two halves.

    Currently, only power-of-two integer bitwidths are supported.
  }];
}
46
// Pass "normalize-memrefs" (anchored on ModuleOp): rewrites memref types that
// carry a non-trivial layout map into memref types with an identity layout,
// updating the users of those memrefs accordingly. The description below is
// Markdown and is used for the generated pass documentation.
def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
  let summary = "Normalize memrefs";
  let description = [{
    This pass transforms memref types with a non-trivial
    [layout map](https://mlir.llvm.org/docs/Dialects/Builtin/#affine-map-layout)
    into memref types with an identity layout map, e.g. (i, j) -> (i, j). This
    pass is inter-procedural, in the sense that it can modify function
    interfaces and call sites that pass memref types. In order to modify
    memref types while preserving the original behavior, users of those
    memref types are also modified to incorporate the resulting layout map.
    For instance, an [AffineLoadOp](https://mlir.llvm.org/docs/Dialects/Affine/#affineload-mliraffineloadop)
    will be updated to compose the layout map with the affine expression
    contained in the op. Operations marked with the
    [MemRefsNormalizable](https://mlir.llvm.org/docs/Traits/#memrefsnormalizable)
    trait are expected to be normalizable. Supported operations include affine
    operations, memref.alloc, memref.dealloc, and func.return.

    Given an appropriate layout map specified in the code, this transformation
    can express tiled or linearized access to multi-dimensional data
    structures, but will not modify memref types without an explicit layout
    map.

    Currently this pass is limited to only modify
    functions where all memref types can be normalized. If a function
    contains any operations that are not MemRefsNormalizable, then the function
    and any functions that call it or are called by it will not be modified.

    Input

    ```mlir
    #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
    func.func @matmul(%A: memref<16xf64, #tile>,
                      %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
      affine.for %arg3 = 0 to 16 {
        %a = affine.load %A[%arg3] : memref<16xf64, #tile>
        %p = arith.mulf %a, %a : f64
        affine.store %p, %A[%arg3] : memref<16xf64, #tile>
      }
      %c = memref.alloc() : memref<16xf64, #tile>
      %d = affine.load %c[0] : memref<16xf64, #tile>
      return %A: memref<16xf64, #tile>
    }
    ```

    Output

    ```mlir
    func.func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
      -> memref<4x4xf64> {
      affine.for %arg3 = 0 to 16 {
        %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
        %4 = arith.mulf %3, %3 : f64
        affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
      }
      %0 = memref.alloc() : memref<4x4xf64>
      %1 = affine.apply #map1()
      %2 = affine.load %0[0, 0] : memref<4x4xf64>
      return %arg0 : memref<4x4xf64>
    }
    ```

    Input

    ```mlir
    #linear8 = affine_map<(i, j) -> (i * 8 + j)>
    func.func @linearize(%arg0: memref<8x8xi32, #linear8>,
                         %arg1: memref<8x8xi32, #linear8>,
                         %arg2: memref<8x8xi32, #linear8>) {
      %c8 = arith.constant 8 : index
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      affine.for %arg3 = %c0 to %c8 {
        affine.for %arg4 = %c0 to %c8 {
          affine.for %arg5 = %c0 to %c8 {
            %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
            %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
            %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
            %3 = arith.muli %0, %1 : i32
            %4 = arith.addi %2, %3 : i32
            affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
          }
        }
      }
      return
    }
    ```

    Output

    ```mlir
    func.func @linearize(%arg0: memref<64xi32>,
                         %arg1: memref<64xi32>,
                         %arg2: memref<64xi32>) {
      %c8 = arith.constant 8 : index
      %c0 = arith.constant 0 : index
      affine.for %arg3 = %c0 to %c8 {
        affine.for %arg4 = %c0 to %c8 {
          affine.for %arg5 = %c0 to %c8 {
            %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
            %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
            %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
            %3 = arith.muli %0, %1 : i32
            %4 = arith.addi %2, %3 : i32
            affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
          }
        }
      }
      return
    }
    ```
  }];
  let constructor = "mlir::memref::createNormalizeMemRefsPass()";
  let dependentDialects = ["affine::AffineDialect"];
}
161
// Pass "resolve-ranked-shaped-type-result-dims": resolves `memref.dim` of
// results of ops implementing `ReifyRankedShapedTypeOpInterface` in terms of
// the shapes of the ops' operands.
def ResolveRankedShapeTypeResultDims :
    Pass<"resolve-ranked-shaped-type-result-dims"> {
  // "ranked shaped type" matches the pass argument and the interface name.
  let summary = "Resolve memref.dim of result values of ranked shaped type";
  let description = [{
    The pass resolves `memref.dim` of results of operations that
    implement the `ReifyRankedShapedTypeOpInterface` in terms of
    the shapes of their operands.
  }];
  let constructor =
      "mlir::memref::createResolveRankedShapeTypeResultDimsPass()";
  let dependentDialects = [
    "memref::MemRefDialect", "tensor::TensorDialect"
  ];
}
176
// Pass "resolve-shaped-type-result-dims": resolves memref.dim of op results
// for ops implementing `InferShapedTypeOpInterface` or
// `ReifyRankedShapedTypeOpInterface`, using the shapes of the operands.
def ResolveShapedTypeResultDims : Pass<"resolve-shaped-type-result-dims"> {
  // Dialects declared as dependencies of this pass.
  let dependentDialects = [
    "affine::AffineDialect",
    "memref::MemRefDialect",
    "tensor::TensorDialect"
  ];
  // C++ factory expression used to instantiate the pass.
  let constructor = "mlir::memref::createResolveShapedTypeResultDimsPass()";
  let summary = "Resolve memref.dim of result values";
  let description = [{
    The pass resolves memref.dim of result of operations that
    implement the `InferShapedTypeOpInterface` or
    `ReifyRankedShapedTypeOpInterface` in terms of shapes of its
    operands.
  }];
}
190
// Pass "expand-strided-metadata": expands memref ops that modify memref
// metadata (sizes, offset, strides) into explicit, easier to analyze
// operations; the supported ops are listed in the description.
def ExpandStridedMetadata : Pass<"expand-strided-metadata"> {
  // Dialects declared as dependencies of this pass.
  let dependentDialects = [
    "affine::AffineDialect",
    "memref::MemRefDialect"
  ];
  // C++ factory expression used to instantiate the pass.
  let constructor = "mlir::memref::createExpandStridedMetadataPass()";
  let summary = "Expand memref operations into easier to analyze constructs";
  let description = [{
    The pass expands memref operations that modify the metadata of a memref
    (sizes, offset, strides) into a sequence of easier to analyze constructs.
    In particular, this pass transforms operations into explicit sequence of
    operations that model the effect of this operation on the different metadata.
    This pass uses affine constructs to materialize these effects.

    Supported ops include:

    - `memref.collapse_shape`
    - `memref.expand_shape`
    - `memref.extract_aligned_pointer_as_index`
    - `memref.extract_strided_metadata`
    - `memref.subview`
  }];
}
213
// Pass "expand-realloc": rewrites `memref.realloc` into a sequence of simpler
// operations (see the example in the description) so that later passes do not
// have to handle `realloc` themselves.
def ExpandRealloc : Pass<"expand-realloc"> {
  let summary = "Expand memref.realloc operations into its components";
  let description = [{
    The `memref.realloc` operation performs a conditional allocation and copy to
    increase the size of a buffer if necessary. This pass converts a `realloc`
    operation into this sequence of simpler operations such that other passes
    at a later stage in the compilation pipeline do not have to consider the
    `realloc` operation anymore (e.g., the buffer deallocation pass and the
    conversion pass to LLVM).

    Example of an expansion:
    ```mlir
    %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
    ```
    is expanded to
    ```mlir
    %c0 = arith.constant 0 : index
    %dim = memref.dim %alloc, %c0 : memref<?xf32>
    %is_old_smaller = arith.cmpi ult, %dim, %size
    %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
      %new_alloc = memref.alloc(%size) : memref<?xf32>
      %subview = memref.subview %new_alloc[0] [%dim] [1]
      memref.copy %alloc, %subview
      memref.dealloc %alloc
      scf.yield %new_alloc : memref<?xf32>
    } else {
      %reinterpret_cast = memref.reinterpret_cast %alloc to
        offset: [0], sizes: [%size], strides: [1]
      scf.yield %reinterpret_cast : memref<?xf32>
    }
    ```
  }];
  // `emit-deallocs` (default: true) controls whether deallocation operations
  // for the original memref are emitted.
  let options = [
    Option<"emitDeallocs", "emit-deallocs", "bool", /*default=*/"true",
           "Emit deallocation operations for the original MemRef">,
  ];
  let constructor = "mlir::memref::createExpandReallocPass()";
  let dependentDialects = [
      "arith::ArithDialect", "scf::SCFDialect", "memref::MemRefDialect"
  ];
}
255
256#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
257
258