xref: /llvm-project/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1b4d6aadaSOleg Shyshkov //===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
2b4d6aadaSOleg Shyshkov //
3b4d6aadaSOleg Shyshkov // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b4d6aadaSOleg Shyshkov // See https://llvm.org/LICENSE.txt for license information.
5b4d6aadaSOleg Shyshkov // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b4d6aadaSOleg Shyshkov //
7b4d6aadaSOleg Shyshkov //===----------------------------------------------------------------------===//
8b4d6aadaSOleg Shyshkov //
9b4d6aadaSOleg Shyshkov // This file implements utilities for the MemRef dialect.
10b4d6aadaSOleg Shyshkov //
11b4d6aadaSOleg Shyshkov //===----------------------------------------------------------------------===//
12b4d6aadaSOleg Shyshkov 
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
20b4d6aadaSOleg Shyshkov 
21b4d6aadaSOleg Shyshkov namespace mlir {
22b4d6aadaSOleg Shyshkov namespace memref {
23b4d6aadaSOleg Shyshkov 
24b4d6aadaSOleg Shyshkov bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
25b4d6aadaSOleg Shyshkov   if (!type.hasStaticShape())
26b4d6aadaSOleg Shyshkov     return false;
27b4d6aadaSOleg Shyshkov 
28b4d6aadaSOleg Shyshkov   SmallVector<int64_t> strides;
29b4d6aadaSOleg Shyshkov   int64_t offset;
30*6aaa8f25SMatthias Springer   if (failed(type.getStridesAndOffset(strides, offset)))
31b4d6aadaSOleg Shyshkov     return false;
32b4d6aadaSOleg Shyshkov 
33b4d6aadaSOleg Shyshkov   // MemRef is contiguous if outer dimensions are size-1 and inner
34b4d6aadaSOleg Shyshkov   // dimensions have unit strides.
35b4d6aadaSOleg Shyshkov   int64_t runningStride = 1;
36b4d6aadaSOleg Shyshkov   int64_t curDim = strides.size() - 1;
37b4d6aadaSOleg Shyshkov   // Finds all inner dimensions with unit strides.
38b4d6aadaSOleg Shyshkov   while (curDim >= 0 && strides[curDim] == runningStride) {
39b4d6aadaSOleg Shyshkov     runningStride *= type.getDimSize(curDim);
40b4d6aadaSOleg Shyshkov     --curDim;
41b4d6aadaSOleg Shyshkov   }
42b4d6aadaSOleg Shyshkov 
43b4d6aadaSOleg Shyshkov   // Check if other dimensions are size-1.
44b4d6aadaSOleg Shyshkov   while (curDim >= 0 && type.getDimSize(curDim) == 1) {
45b4d6aadaSOleg Shyshkov     --curDim;
46b4d6aadaSOleg Shyshkov   }
47b4d6aadaSOleg Shyshkov 
48b4d6aadaSOleg Shyshkov   // All dims are unit-strided or size-1.
49b4d6aadaSOleg Shyshkov   return curDim < 0;
501c983af9SKazu Hirata }
51b4d6aadaSOleg Shyshkov 
520f8bab8dSMahesh Ravishankar std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
530f8bab8dSMahesh Ravishankar     OpBuilder &builder, Location loc, int srcBits, int dstBits,
540f8bab8dSMahesh Ravishankar     OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
550f8bab8dSMahesh Ravishankar     ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
560f8bab8dSMahesh Ravishankar   unsigned sourceRank = sizes.size();
570f8bab8dSMahesh Ravishankar   assert(sizes.size() == strides.size() &&
580f8bab8dSMahesh Ravishankar          "expected as many sizes as strides for a memref");
59c730c627SJie Fu   SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
600f8bab8dSMahesh Ravishankar   if (indices.empty())
610f8bab8dSMahesh Ravishankar     indicesVec.resize(sourceRank, builder.getIndexAttr(0));
620f8bab8dSMahesh Ravishankar   assert(indicesVec.size() == strides.size() &&
630f8bab8dSMahesh Ravishankar          "expected as many indices as rank of memref");
648fc433f0SHanhan Wang 
658fc433f0SHanhan Wang   // Create the affine symbols and values for linearization.
660f8bab8dSMahesh Ravishankar   SmallVector<AffineExpr> symbols(2 * sourceRank);
678fc433f0SHanhan Wang   bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
680f8bab8dSMahesh Ravishankar   AffineExpr addMulMap = builder.getAffineConstantExpr(0);
690f8bab8dSMahesh Ravishankar   AffineExpr mulMap = builder.getAffineConstantExpr(1);
708fc433f0SHanhan Wang 
710f8bab8dSMahesh Ravishankar   SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
728fc433f0SHanhan Wang 
738fc433f0SHanhan Wang   for (unsigned i = 0; i < sourceRank; ++i) {
740f8bab8dSMahesh Ravishankar     unsigned offsetIdx = 2 * i;
758fc433f0SHanhan Wang     addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
760f8bab8dSMahesh Ravishankar     offsetValues[offsetIdx] = indicesVec[i];
770f8bab8dSMahesh Ravishankar     offsetValues[offsetIdx + 1] = strides[i];
788fc433f0SHanhan Wang 
790f8bab8dSMahesh Ravishankar     mulMap = mulMap * symbols[i];
808fc433f0SHanhan Wang   }
818fc433f0SHanhan Wang 
82e3c9c82cSHan-Chung Wang   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
830f8bab8dSMahesh Ravishankar   int64_t scaler = dstBits / srcBits;
840f8bab8dSMahesh Ravishankar   mulMap = mulMap.floorDiv(scaler);
858fc433f0SHanhan Wang 
860f8bab8dSMahesh Ravishankar   OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
872c313259Slialan       builder, loc, addMulMap.floorDiv(scaler), offsetValues);
888fc433f0SHanhan Wang   OpFoldResult linearizedSize =
890f8bab8dSMahesh Ravishankar       affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
908fc433f0SHanhan Wang 
918fc433f0SHanhan Wang   // Adjust baseOffset by the scale factor (dstBits / srcBits).
920f8bab8dSMahesh Ravishankar   AffineExpr s0;
930f8bab8dSMahesh Ravishankar   bindSymbols(builder.getContext(), s0);
940f8bab8dSMahesh Ravishankar   OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
950f8bab8dSMahesh Ravishankar       builder, loc, s0.floorDiv(scaler), {offset});
960f8bab8dSMahesh Ravishankar 
972c313259Slialan   OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
982c313259Slialan       builder, loc, addMulMap % scaler, offsetValues);
992c313259Slialan 
1002c313259Slialan   return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
1012c313259Slialan           linearizedIndices};
1020f8bab8dSMahesh Ravishankar }
1030f8bab8dSMahesh Ravishankar 
1040f8bab8dSMahesh Ravishankar LinearizedMemRefInfo
1050f8bab8dSMahesh Ravishankar getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
1060f8bab8dSMahesh Ravishankar                                  int dstBits, OpFoldResult offset,
1070f8bab8dSMahesh Ravishankar                                  ArrayRef<OpFoldResult> sizes) {
1080f8bab8dSMahesh Ravishankar   SmallVector<OpFoldResult> strides(sizes.size());
1096cde64a9SAdrian Kuegel   if (!sizes.empty()) {
1100f8bab8dSMahesh Ravishankar     strides.back() = builder.getIndexAttr(1);
1118fc433f0SHanhan Wang     AffineExpr s0, s1;
1128fc433f0SHanhan Wang     bindSymbols(builder.getContext(), s0, s1);
1130f8bab8dSMahesh Ravishankar     for (int index = sizes.size() - 1; index > 0; --index) {
1140f8bab8dSMahesh Ravishankar       strides[index - 1] = affine::makeComposedFoldedAffineApply(
1150f8bab8dSMahesh Ravishankar           builder, loc, s0 * s1,
1160f8bab8dSMahesh Ravishankar           ArrayRef<OpFoldResult>{strides[index], sizes[index]});
1170f8bab8dSMahesh Ravishankar     }
1180f8bab8dSMahesh Ravishankar   }
1198fc433f0SHanhan Wang 
1200f8bab8dSMahesh Ravishankar   LinearizedMemRefInfo linearizedMemRefInfo;
1210f8bab8dSMahesh Ravishankar   std::tie(linearizedMemRefInfo, std::ignore) =
1220f8bab8dSMahesh Ravishankar       getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
1230f8bab8dSMahesh Ravishankar                                        sizes, strides);
1240f8bab8dSMahesh Ravishankar   return linearizedMemRefInfo;
1258fc433f0SHanhan Wang }
1268fc433f0SHanhan Wang 
127c5dee18bSHanhan Wang /// Returns true if all the uses of op are not read/load.
128c5dee18bSHanhan Wang /// There can be SubviewOp users as long as all its users are also
129c5dee18bSHanhan Wang /// StoreOp/transfer_write. If return true it also fills out the uses, if it
130c5dee18bSHanhan Wang /// returns false uses is unchanged.
131c5dee18bSHanhan Wang static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
132c5dee18bSHanhan Wang   std::vector<Operation *> opUses;
133c5dee18bSHanhan Wang   for (OpOperand &use : op->getUses()) {
134c5dee18bSHanhan Wang     Operation *useOp = use.getOwner();
135c5dee18bSHanhan Wang     if (isa<memref::DeallocOp>(useOp) ||
136c5dee18bSHanhan Wang         (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
137c5dee18bSHanhan Wang          !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
138c5dee18bSHanhan Wang         (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
139c5dee18bSHanhan Wang       opUses.push_back(useOp);
140c5dee18bSHanhan Wang       continue;
141c5dee18bSHanhan Wang     }
142c5dee18bSHanhan Wang     return false;
143c5dee18bSHanhan Wang   }
144c5dee18bSHanhan Wang   uses.insert(uses.end(), opUses.begin(), opUses.end());
145c5dee18bSHanhan Wang   return true;
146c5dee18bSHanhan Wang }
147c5dee18bSHanhan Wang 
148c5dee18bSHanhan Wang void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
149c5dee18bSHanhan Wang   std::vector<Operation *> opToErase;
150c5dee18bSHanhan Wang   parentOp->walk([&](memref::AllocOp op) {
151c5dee18bSHanhan Wang     std::vector<Operation *> candidates;
152c5dee18bSHanhan Wang     if (resultIsNotRead(op, candidates)) {
153c5dee18bSHanhan Wang       opToErase.insert(opToErase.end(), candidates.begin(), candidates.end());
154c5dee18bSHanhan Wang       opToErase.push_back(op.getOperation());
155c5dee18bSHanhan Wang     }
156c5dee18bSHanhan Wang   });
157c5dee18bSHanhan Wang   for (Operation *op : opToErase)
158c5dee18bSHanhan Wang     rewriter.eraseOp(op);
159c5dee18bSHanhan Wang }
160c5dee18bSHanhan Wang 
1616ed8434eSPrathamesh Tagore static SmallVector<OpFoldResult>
1626ed8434eSPrathamesh Tagore computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
1636ed8434eSPrathamesh Tagore                                 ArrayRef<OpFoldResult> sizes,
1646ed8434eSPrathamesh Tagore                                 OpFoldResult unit) {
1656ed8434eSPrathamesh Tagore   SmallVector<OpFoldResult> strides(sizes.size(), unit);
1666ed8434eSPrathamesh Tagore   AffineExpr s0, s1;
1676ed8434eSPrathamesh Tagore   bindSymbols(builder.getContext(), s0, s1);
1686ed8434eSPrathamesh Tagore 
1696ed8434eSPrathamesh Tagore   for (int64_t r = strides.size() - 1; r > 0; --r) {
1706ed8434eSPrathamesh Tagore     strides[r - 1] = affine::makeComposedFoldedAffineApply(
1716ed8434eSPrathamesh Tagore         builder, loc, s0 * s1, {strides[r], sizes[r]});
1726ed8434eSPrathamesh Tagore   }
1736ed8434eSPrathamesh Tagore   return strides;
1746ed8434eSPrathamesh Tagore }
1756ed8434eSPrathamesh Tagore 
1766ed8434eSPrathamesh Tagore SmallVector<OpFoldResult>
1776ed8434eSPrathamesh Tagore computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
1786ed8434eSPrathamesh Tagore                             ArrayRef<OpFoldResult> sizes) {
1796ed8434eSPrathamesh Tagore   OpFoldResult unit = builder.getIndexAttr(1);
1806ed8434eSPrathamesh Tagore   return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
1816ed8434eSPrathamesh Tagore }
1826ed8434eSPrathamesh Tagore 
18390d2f8c6SBenjamin Maxwell MemrefValue skipFullyAliasingOperations(MemrefValue source) {
18490d2f8c6SBenjamin Maxwell   while (auto op = source.getDefiningOp()) {
18590d2f8c6SBenjamin Maxwell     if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
18690d2f8c6SBenjamin Maxwell         subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
18790d2f8c6SBenjamin Maxwell       // A `memref.subview` with an all zero offset, and all unit strides, still
18890d2f8c6SBenjamin Maxwell       // points to the same memory.
18990d2f8c6SBenjamin Maxwell       source = cast<MemrefValue>(subViewOp.getSource());
19090d2f8c6SBenjamin Maxwell     } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
19190d2f8c6SBenjamin Maxwell       // A `memref.cast` still points to the same memory.
19290d2f8c6SBenjamin Maxwell       source = castOp.getSource();
19390d2f8c6SBenjamin Maxwell     } else {
19490d2f8c6SBenjamin Maxwell       return source;
19590d2f8c6SBenjamin Maxwell     }
19690d2f8c6SBenjamin Maxwell   }
19790d2f8c6SBenjamin Maxwell   return source;
19890d2f8c6SBenjamin Maxwell }
19990d2f8c6SBenjamin Maxwell 
2004e2efea5SQuinn Dawkins MemrefValue skipViewLikeOps(MemrefValue source) {
20190d2f8c6SBenjamin Maxwell   while (auto op = source.getDefiningOp()) {
2024e2efea5SQuinn Dawkins     if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
2034e2efea5SQuinn Dawkins       source = cast<MemrefValue>(viewLike.getViewSource());
2044e2efea5SQuinn Dawkins       continue;
20590d2f8c6SBenjamin Maxwell     }
2064e2efea5SQuinn Dawkins     return source;
20790d2f8c6SBenjamin Maxwell   }
20890d2f8c6SBenjamin Maxwell   return source;
20990d2f8c6SBenjamin Maxwell }
21090d2f8c6SBenjamin Maxwell 
211b4d6aadaSOleg Shyshkov } // namespace memref
212b4d6aadaSOleg Shyshkov } // namespace mlir
213