//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the MemRef dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace memref {

/// Returns true if `type` has a static shape and its strided layout is
/// contiguous row-major: scanning from the innermost dimension outward, each
/// dimension either has the expected running (row-major) stride or has
/// size 1 (in which case its stride is irrelevant).
bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
  if (!type.hasStaticShape())
    return false;

  SmallVector<int64_t> strides;
  int64_t offset;
  // Non-strided layouts (e.g. generic affine maps) cannot be contiguous
  // row-major in this sense.
  if (failed(type.getStridesAndOffset(strides, offset)))
    return false;

  // MemRef is contiguous if outer dimensions are size-1 and inner
  // dimensions have unit strides.
  int64_t runningStride = 1;
  int64_t curDim = strides.size() - 1;
  // Finds all inner dimensions with unit strides.
  // `runningStride` accumulates the product of inner dimension sizes, i.e.
  // the stride each successive outer dimension must have to be contiguous.
  while (curDim >= 0 && strides[curDim] == runningStride) {
    runningStride *= type.getDimSize(curDim);
    --curDim;
  }

  // Check if other dimensions are size-1.
  // Remaining outer dimensions may have arbitrary strides as long as their
  // size is 1 (a single element is trivially contiguous).
  while (curDim >= 0 && type.getDimSize(curDim) == 1) {
    --curDim;
  }

  // All dims are unit-strided or size-1.
  return curDim < 0;
}

/// Linearizes a multi-dimensional access (indices/sizes/strides) into a
/// single offset and total size, rescaled for an element-bit-width change
/// from `srcBits` to `dstBits` (e.g. packing i4 elements into i8 storage).
///
/// Returns a pair of:
///  - LinearizedMemRefInfo: the rescaled base offset, the linearized total
///    size, and the intra-"vector" remainder of the linearized index modulo
///    the scale factor; and
///  - the linearized (and rescaled) index itself.
///
/// If `indices` is empty, all-zero indices are assumed.
/// NOTE(review): `scaler` is computed as `dstBits / srcBits` with integer
/// division — this presumes srcBits evenly divides dstBits (and
/// srcBits <= dstBits); confirm callers guarantee this, otherwise scaler
/// could be 0.
std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
    OpBuilder &builder, Location loc, int srcBits, int dstBits,
    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
  unsigned sourceRank = sizes.size();
  assert(sizes.size() == strides.size() &&
         "expected as many sizes as strides for a memref");
  SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
  if (indices.empty())
    indicesVec.resize(sourceRank, builder.getIndexAttr(0));
  assert(indicesVec.size() == strides.size() &&
         "expected as many indices as rank of memref");

  // Create the affine symbols and values for linearization.
  // Symbols are laid out in (index, stride) pairs: symbol 2*i binds the i-th
  // index and symbol 2*i+1 binds the i-th stride.
  SmallVector<AffineExpr> symbols(2 * sourceRank);
  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
  // addMulMap accumulates sum_i(index_i * stride_i): the linearized index.
  AffineExpr addMulMap = builder.getAffineConstantExpr(0);
  // mulMap accumulates the product of sizes: the linearized total size.
  AffineExpr mulMap = builder.getAffineConstantExpr(1);

  SmallVector<OpFoldResult> offsetValues(2 * sourceRank);

  for (unsigned i = 0; i < sourceRank; ++i) {
    unsigned offsetIdx = 2 * i;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indicesVec[i];
    offsetValues[offsetIdx + 1] = strides[i];

    // The size product reuses the first `sourceRank` symbols; it is applied
    // below against the `sizes` operand list (not `offsetValues`).
    mulMap = mulMap * symbols[i];
  }

  // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
  int64_t scaler = dstBits / srcBits;
  mulMap = mulMap.floorDiv(scaler);

  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap.floorDiv(scaler), offsetValues);
  OpFoldResult linearizedSize =
      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);

  // Adjust baseOffset by the scale factor (dstBits / srcBits).
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {offset});

  // Remainder of the linearized index modulo the scale factor: the position
  // within a packed group of `scaler` source elements.
  OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap % scaler, offsetValues);

  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
          linearizedIndices};
}

/// Convenience overload that assumes a contiguous row-major layout: strides
/// are computed as the suffix product of `sizes` (innermost stride 1), and
/// only the LinearizedMemRefInfo part of the result is returned (the
/// linearized index — computed for all-zero indices — is discarded).
LinearizedMemRefInfo
getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
                                 int dstBits, OpFoldResult offset,
                                 ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> strides(sizes.size());
  if (!sizes.empty()) {
    // Row-major suffix product: stride[r-1] = stride[r] * size[r], with the
    // innermost stride equal to 1.
    strides.back() = builder.getIndexAttr(1);
    AffineExpr s0, s1;
    bindSymbols(builder.getContext(), s0, s1);
    for (int index = sizes.size() - 1; index > 0; --index) {
      strides[index - 1] = affine::makeComposedFoldedAffineApply(
          builder, loc, s0 * s1,
          ArrayRef<OpFoldResult>{strides[index], sizes[index]});
    }
  }

  LinearizedMemRefInfo linearizedMemRefInfo;
  std::tie(linearizedMemRefInfo, std::ignore) =
      getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
                                       sizes, strides);
  return linearizedMemRefInfo;
}

/// Returns true if all the uses of op are not read/load.
/// There can be SubviewOp users as long as all its users are also
/// StoreOp/transfer_write. If return true it also fills out the uses, if it
/// returns false uses is unchanged.
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
  // Collect into a local vector first so that `uses` stays untouched when a
  // reading user is found and we return false.
  std::vector<Operation *> opUses;
  for (OpOperand &use : op->getUses()) {
    Operation *useOp = use.getOwner();
    // A user is acceptable if it is:
    //  - a dealloc (never reads the data), or
    //  - a region-free, result-free op with no Read memory effect
    //    (e.g. a store / transfer_write), or
    //  - a subview all of whose (transitive) users are themselves non-reads.
    if (isa<memref::DeallocOp>(useOp) ||
        (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
         !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
        (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
      opUses.push_back(useOp);
      continue;
    }
    return false;
  }
  uses.insert(uses.end(), opUses.begin(), opUses.end());
  return true;
}

/// Erases `memref.alloc` ops under `parentOp` whose results are never read,
/// together with all of their (write-only) users.
void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
  std::vector<Operation *> opToErase;
  // First collect every dead alloc and its users; erase only after the walk
  // so the IR is not mutated while being traversed.
  parentOp->walk([&](memref::AllocOp op) {
    std::vector<Operation *> candidates;
    if (resultIsNotRead(op, candidates)) {
      // Users are queued before the alloc itself, so uses are erased before
      // their defining op.
      opToErase.insert(opToErase.end(), candidates.begin(), candidates.end());
      opToErase.push_back(op.getOperation());
    }
  });
  for (Operation *op : opToErase)
    rewriter.eraseOp(op);
}

/// Materializes the suffix product of `sizes` in IR: strides[i] =
/// prod(sizes[i+1..N-1]) * unit, with strides[N-1] = unit. Each product is
/// emitted as a (folded) affine.apply.
static SmallVector<OpFoldResult>
computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
                                ArrayRef<OpFoldResult> sizes,
                                OpFoldResult unit) {
  SmallVector<OpFoldResult> strides(sizes.size(), unit);
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  // Walk outward from the innermost dimension, multiplying in each size.
  for (int64_t r = strides.size() - 1; r > 0; --r) {
    strides[r - 1] = affine::makeComposedFoldedAffineApply(
        builder, loc, s0 * s1, {strides[r], sizes[r]});
  }
  return strides;
}

/// Public entry point for the suffix product with a unit of constant 1,
/// i.e. the row-major strides implied by `sizes`.
SmallVector<OpFoldResult>
computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
                            ArrayRef<OpFoldResult> sizes) {
  OpFoldResult unit = builder.getIndexAttr(1);
  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
}

/// Walks up the def-use chain past operations known to alias the *entire*
/// source memref (zero-offset unit-stride subviews and casts), returning the
/// furthest such source value.
MemrefValue skipFullyAliasingOperations(MemrefValue source) {
  while (auto op = source.getDefiningOp()) {
    if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
        subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
      // A `memref.subview` with an all zero offset, and all unit strides, still
      // points to the same memory.
      source = cast<MemrefValue>(subViewOp.getSource());
    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
      // A `memref.cast` still points to the same memory.
      source = castOp.getSource();
    } else {
      // Any other defining op may alias only part of the memory; stop here.
      return source;
    }
  }
  // Reached a block argument (no defining op).
  return source;
}

/// Walks up the def-use chain past *any* ViewLikeOpInterface op (a broader
/// skip than skipFullyAliasingOperations: views may alias only part of the
/// source), returning the underlying source value.
MemrefValue skipViewLikeOps(MemrefValue source) {
  while (auto op = source.getDefiningOp()) {
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
      source = cast<MemrefValue>(viewLike.getViewSource());
      continue;
    }
    return source;
  }
  // Reached a block argument (no defining op).
  return source;
}

} // namespace memref
} // namespace mlir