// xref: /llvm-project/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h (revision 8ff2da782d676edddc19d856a853c1ebab999fc2)
//===- MemRefUtils.h - MemRef transformation utilities ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines prototypes for various transformation utilities for
// the MemRefOps dialect. These are not passes by themselves but are used
// either by passes, optimization sequences, or in turn by other transformation
// utilities.
//
//===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
17 #define MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
18 
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 
21 namespace mlir {
22 
class MemRefType;

/// A value with a memref type (ranked or unranked, per `BaseMemRefType`).
using MemrefValue = TypedValue<BaseMemRefType>;
27 
28 namespace memref {
29 
/// Returns true if the memref type has a fully static shape and represents a
/// contiguous chunk of memory, i.e. its elements are laid out row-major with
/// no gaps between them.
bool isStaticShapeAndContiguousRowMajor(MemRefType type);
33 
/// For a `memref` with `offset`, `sizes` and `strides`, returns the
/// offset, size, and potentially the size padded at the front to use for the
/// linearized `memref`.
/// - If the linearization is done for emulating load/stores of
///   element type with bitwidth `srcBits` using element type with
///   bitwidth `dstBits`, the linearized offset and size are
///   scaled down by `dstBits`/`srcBits`.
/// - If `indices` is provided, it represents the position in the
///   original `memref` being accessed. The method then returns the
///   index to use in the linearized `memref`. The linearized index
///   is also scaled down by `dstBits`/`srcBits`. If `indices` is not
///   provided, 0 is returned for the linearized index.
/// - If the size of the load/store is smaller than the linearized memref
///   load/store, the memory region emulated is larger than the actual memory
///   region needed. `intraDataOffset` returns the element offset of the data
///   relevant at the beginning.
struct LinearizedMemRefInfo {
  /// Offset into the linearized memref (scaled down when emulating a
  /// narrower element type, per the `srcBits`/`dstBits` note above).
  OpFoldResult linearizedOffset;
  /// Total size of the linearized memref (scaled the same way).
  OpFoldResult linearizedSize;
  /// Element offset of the relevant data within the (possibly wider)
  /// emulated region.
  OpFoldResult intraDataOffset;
};
std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
    OpBuilder &builder, Location loc, int srcBits, int dstBits,
    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices = {});
59 
/// Overload of the above for the common case where the strides are implied:
/// for a `memref` with `offset` and `sizes`, returns the offset and size to
/// use for the linearized `memref`, assuming that the strides are computed
/// from a row-major ordering of the sizes.
/// - If the linearization is done for emulating load/stores of
///   element type with bitwidth `srcBits` using element type with
///   bitwidth `dstBits`, the linearized offset and size are
///   scaled down by `dstBits`/`srcBits`.
LinearizedMemRefInfo
getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
                                 int dstBits, OpFoldResult offset,
                                 ArrayRef<OpFoldResult> sizes);
71 
/// Track, within `parentOp`, temporary allocations that are never read from.
/// If this is the case it means both the allocations and associated stores
/// can be removed; `rewriter` performs the erasure.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp);
75 
/// Given a set of sizes, return the suffix product.
///
/// When applied to slicing, this is the calculation needed to derive the
/// strides (i.e. the number of linear indices to skip along the (k-1) most
/// minor dimensions to get the next k-slice).
///
/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
///
/// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value>
///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
///
/// It is the caller's responsibility to provide valid OpFoldResult type values
/// and construct valid IR in the end.
///
/// `sizes` elements are asserted to be non-negative.
///
/// Returns an empty vector if `sizes` is empty.
///
/// The function emits an IR block (using `builder` at `loc`) which computes
/// the suffix product for the provided sizes.
SmallVector<OpFoldResult>
computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
                            ArrayRef<OpFoldResult> sizes);
99 inline SmallVector<OpFoldResult>
100 computeStridesIRBlock(Location loc, OpBuilder &builder,
101                       ArrayRef<OpFoldResult> sizes) {
102   return computeSuffixProductIRBlock(loc, builder, sizes);
103 }
104 
/// Walk up the source chain until an operation that changes/defines the view
/// of memory is found (i.e. skip operations that alias the entire view) and
/// return that underlying value.
MemrefValue skipFullyAliasingOperations(MemrefValue source);
108 
109 /// Checks if two (memref) values are the same or statically known to alias
110 /// the same region of memory.
111 inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
112   return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
113 }
114 
/// Walk up the source chain until we find an operation that is not a view of
/// the source memref, i.e. skip operations that implement
/// ViewLikeOpInterface, and return the underlying value.
MemrefValue skipViewLikeOps(MemrefValue source);
118 
119 } // namespace memref
120 } // namespace mlir
121 
122 #endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
123