1 //===- MemRefUtils.h - MemRef transformation utilities ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines prototypes for various transformation utilities for 10 // the MemRefOps dialect. These are not passes by themselves but are used 11 // either by passes, optimization sequences, or in turn by other transformation 12 // utilities. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #ifndef MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H 17 #define MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H 18 19 #include "mlir/Dialect/MemRef/IR/MemRef.h" 20 21 namespace mlir { 22 23 class MemRefType; 24 25 /// A value with a memref type. 26 using MemrefValue = TypedValue<BaseMemRefType>; 27 28 namespace memref { 29 30 /// Returns true, if the memref type has static shapes and represents a 31 /// contiguous chunk of memory. 32 bool isStaticShapeAndContiguousRowMajor(MemRefType type); 33 34 /// For a `memref` with `offset`, `sizes` and `strides`, returns the 35 /// offset, size, and potentially the size padded at the front to use for the 36 /// linearized `memref`. 37 /// - If the linearization is done for emulating load/stores of 38 /// element type with bitwidth `srcBits` using element type with 39 /// bitwidth `dstBits`, the linearized offset and size are 40 /// scaled down by `dstBits`/`srcBits`. 41 /// - If `indices` is provided, it represents the position in the 42 /// original `memref` being accessed. The method then returns the 43 /// index to use in the linearized `memref`. The linearized index 44 /// is also scaled down by `dstBits`/`srcBits`. If `indices` is not provided 45 /// 0, is returned for the linearized index. 46 /// - If the size of the load/store is smaller than the linearized memref 47 /// load/store, the memory region emulated is larger than the actual memory 48 /// region needed. `intraDataOffset` returns the element offset of the data 49 /// relevant at the beginning. 50 struct LinearizedMemRefInfo { 51 OpFoldResult linearizedOffset; 52 OpFoldResult linearizedSize; 53 OpFoldResult intraDataOffset; 54 }; 55 std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize( 56 OpBuilder &builder, Location loc, int srcBits, int dstBits, 57 OpFoldResult offset, ArrayRef<OpFoldResult> sizes, 58 ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices = {}); 59 60 /// For a `memref` with `offset` and `sizes`, returns the 61 /// offset and size to use for the linearized `memref`, assuming that 62 /// the strides are computed from a row-major ordering of the sizes; 63 /// - If the linearization is done for emulating load/stores of 64 /// element type with bitwidth `srcBits` using element type with 65 /// bitwidth `dstBits`, the linearized offset and size are 66 /// scaled down by `dstBits`/`srcBits`. 67 LinearizedMemRefInfo 68 getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, 69 int dstBits, OpFoldResult offset, 70 ArrayRef<OpFoldResult> sizes); 71 72 /// Track temporary allocations that are never read from. If this is the case 73 /// it means both the allocations and associated stores can be removed. 74 void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp); 75 76 /// Given a set of sizes, return the suffix product. 77 /// 78 /// When applied to slicing, this is the calculation needed to derive the 79 /// strides (i.e. the number of linear indices to skip along the (k-1) most 80 /// minor dimensions to get the next k-slice). 81 /// 82 /// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`. 83 /// 84 /// Assuming `sizes` is `[s0, .. sn]`, return the vector<Value> 85 /// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`. 86 /// 87 /// It is the caller's responsibility to provide valid OpFoldResult type values 88 /// and construct valid IR in the end. 89 /// 90 /// `sizes` elements are asserted to be non-negative. 91 /// 92 /// Return an empty vector if `sizes` is empty. 93 /// 94 /// The function emits an IR block which computes suffix product for provided 95 /// sizes. 96 SmallVector<OpFoldResult> 97 computeSuffixProductIRBlock(Location loc, OpBuilder &builder, 98 ArrayRef<OpFoldResult> sizes); 99 inline SmallVector<OpFoldResult> 100 computeStridesIRBlock(Location loc, OpBuilder &builder, 101 ArrayRef<OpFoldResult> sizes) { 102 return computeSuffixProductIRBlock(loc, builder, sizes); 103 } 104 105 /// Walk up the source chain until an operation that changes/defines the view of 106 /// memory is found (i.e. skip operations that alias the entire view). 107 MemrefValue skipFullyAliasingOperations(MemrefValue source); 108 109 /// Checks if two (memref) values are the same or statically known to alias 110 /// the same region of memory. 111 inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) { 112 return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b); 113 } 114 115 /// Walk up the source chain until we find an operation that is not a view of 116 /// the source memref (i.e. implements ViewLikeOpInterface). 117 MemrefValue skipViewLikeOps(MemrefValue source); 118 119 } // namespace memref 120 } // namespace mlir 121 122 #endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H 123