11b002d27SArnab Dutta //===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===// 21b002d27SArnab Dutta // 31b002d27SArnab Dutta // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 41b002d27SArnab Dutta // See https://llvm.org/LICENSE.txt for license information. 51b002d27SArnab Dutta // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 61b002d27SArnab Dutta // 71b002d27SArnab Dutta //===----------------------------------------------------------------------===// 81b002d27SArnab Dutta // 91b002d27SArnab Dutta // This transformation pass folds loading/storing from/to subview ops into 101b002d27SArnab Dutta // loading/storing from/to the original memref. 111b002d27SArnab Dutta // 121b002d27SArnab Dutta //===----------------------------------------------------------------------===// 131b002d27SArnab Dutta 141b002d27SArnab Dutta #include "mlir/Dialect/Affine/IR/AffineOps.h" 154dc72d47SNicolas Vasilache #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" 16abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 17829446cbSNicolas Vasilache #include "mlir/Dialect/Arith/Utils/Utils.h" 18829446cbSNicolas Vasilache #include "mlir/Dialect/GPU/IR/GPUDialect.h" 191b002d27SArnab Dutta #include "mlir/Dialect/MemRef/IR/MemRef.h" 20829446cbSNicolas Vasilache #include "mlir/Dialect/MemRef/Transforms/Passes.h" 21faafd26cSQuentin Colombet #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 226ed8434eSPrathamesh Tagore #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 23fc5c1a76SManish Gupta #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 241b002d27SArnab Dutta #include "mlir/Dialect/Utils/IndexingUtils.h" 251b002d27SArnab Dutta #include "mlir/Dialect/Vector/IR/VectorOps.h" 264dc72d47SNicolas Vasilache #include "mlir/IR/AffineMap.h" 271b002d27SArnab Dutta #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 284dc72d47SNicolas Vasilache #include "llvm/ADT/STLExtras.h" 291b002d27SArnab Dutta #include "llvm/ADT/SmallBitVector.h" 
301b002d27SArnab Dutta #include "llvm/ADT/TypeSwitch.h" 31fc5c1a76SManish Gupta #include "llvm/Support/Debug.h" 32fc5c1a76SManish Gupta 33fc5c1a76SManish Gupta #define DEBUG_TYPE "fold-memref-alias-ops" 34fc5c1a76SManish Gupta #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 351b002d27SArnab Dutta 3667d0d7acSMichele Scuttari namespace mlir { 3767d0d7acSMichele Scuttari namespace memref { 3867d0d7acSMichele Scuttari #define GEN_PASS_DEF_FOLDMEMREFALIASOPS 3967d0d7acSMichele Scuttari #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 4067d0d7acSMichele Scuttari } // namespace memref 4167d0d7acSMichele Scuttari } // namespace mlir 4267d0d7acSMichele Scuttari 431b002d27SArnab Dutta using namespace mlir; 441b002d27SArnab Dutta 451b002d27SArnab Dutta //===----------------------------------------------------------------------===// 461b002d27SArnab Dutta // Utility functions 471b002d27SArnab Dutta //===----------------------------------------------------------------------===// 481b002d27SArnab Dutta 491b002d27SArnab Dutta /// Given the 'indices' of a load/store operation where the memref is a result 501b002d27SArnab Dutta /// of a expand_shape op, returns the indices w.r.t to the source memref of the 511b002d27SArnab Dutta /// expand_shape op. For example 521b002d27SArnab Dutta /// 531b002d27SArnab Dutta /// %0 = ... 
: memref<12x42xf32> 541b002d27SArnab Dutta /// %1 = memref.expand_shape %0 [[0, 1], [2]] 551b002d27SArnab Dutta /// : memref<12x42xf32> into memref<2x6x42xf32> 561b002d27SArnab Dutta /// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32 571b002d27SArnab Dutta /// 581b002d27SArnab Dutta /// could be folded into 591b002d27SArnab Dutta /// 601b002d27SArnab Dutta /// %2 = load %0[6 * i1 + i2, %i3] : 611b002d27SArnab Dutta /// memref<12x42xf32> 621b002d27SArnab Dutta static LogicalResult 631b002d27SArnab Dutta resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, 641b002d27SArnab Dutta memref::ExpandShapeOp expandShapeOp, 651b002d27SArnab Dutta ValueRange indices, 661b002d27SArnab Dutta SmallVectorImpl<Value> &sourceIndices) { 676ed8434eSPrathamesh Tagore // Record the rewriter context for constructing ops later. 686ed8434eSPrathamesh Tagore MLIRContext *ctx = rewriter.getContext(); 696ed8434eSPrathamesh Tagore 706ed8434eSPrathamesh Tagore // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`. 716ed8434eSPrathamesh Tagore // This is done for the purpose of inferring the output shape via 726ed8434eSPrathamesh Tagore // `inferExpandOutputShape` which will in turn be used for suffix product 736ed8434eSPrathamesh Tagore // calculation later. 
746ed8434eSPrathamesh Tagore SmallVector<OpFoldResult> srcShape; 756ed8434eSPrathamesh Tagore MemRefType srcType = expandShapeOp.getSrcType(); 766ed8434eSPrathamesh Tagore 776ed8434eSPrathamesh Tagore for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) { 786ed8434eSPrathamesh Tagore if (srcType.isDynamicDim(i)) { 796ed8434eSPrathamesh Tagore srcShape.push_back( 806ed8434eSPrathamesh Tagore rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i) 816ed8434eSPrathamesh Tagore .getResult()); 826ed8434eSPrathamesh Tagore } else { 836ed8434eSPrathamesh Tagore srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i])); 846ed8434eSPrathamesh Tagore } 856ed8434eSPrathamesh Tagore } 866ed8434eSPrathamesh Tagore 876ed8434eSPrathamesh Tagore auto outputShape = inferExpandShapeOutputShape( 886ed8434eSPrathamesh Tagore rewriter, loc, expandShapeOp.getResultType(), 896ed8434eSPrathamesh Tagore expandShapeOp.getReassociationIndices(), srcShape); 906ed8434eSPrathamesh Tagore if (!outputShape.has_value()) 91f6897c37SHanhan Wang return failure(); 92f6897c37SHanhan Wang 936ed8434eSPrathamesh Tagore // Traverse all reassociation groups to determine the appropriate indices 946ed8434eSPrathamesh Tagore // corresponding to each one of them post op folding. 95203fad47SNicolas Vasilache for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) { 961b002d27SArnab Dutta assert(!groups.empty() && "association indices groups cannot be empty"); 976ed8434eSPrathamesh Tagore // Flag to indicate the presence of dynamic dimensions in current 986ed8434eSPrathamesh Tagore // reassociation group. 99203fad47SNicolas Vasilache int64_t groupSize = groups.size(); 100203fad47SNicolas Vasilache 1016ed8434eSPrathamesh Tagore // Group output dimensions utilized in this reassociation group for suffix 1026ed8434eSPrathamesh Tagore // product calculation. 
1036ed8434eSPrathamesh Tagore SmallVector<OpFoldResult> sizesVal(groupSize); 1046ed8434eSPrathamesh Tagore for (int64_t i = 0; i < groupSize; ++i) { 1056ed8434eSPrathamesh Tagore sizesVal[i] = (*outputShape)[groups[i]]; 1066ed8434eSPrathamesh Tagore } 107203fad47SNicolas Vasilache 1086ed8434eSPrathamesh Tagore // Calculate suffix product of relevant output dimension sizes. 1096ed8434eSPrathamesh Tagore SmallVector<OpFoldResult> suffixProduct = 1106ed8434eSPrathamesh Tagore memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal); 1116ed8434eSPrathamesh Tagore 1126ed8434eSPrathamesh Tagore // Create affine expression variables for dimensions and symbols in the 1136ed8434eSPrathamesh Tagore // newly constructed affine map. 1146ed8434eSPrathamesh Tagore SmallVector<AffineExpr> dims(groupSize), symbols(groupSize); 1156ed8434eSPrathamesh Tagore bindDimsList<AffineExpr>(ctx, dims); 1166ed8434eSPrathamesh Tagore bindSymbolsList<AffineExpr>(ctx, symbols); 1176ed8434eSPrathamesh Tagore 1186ed8434eSPrathamesh Tagore // Linearize binded dimensions and symbols to construct the resultant 1196ed8434eSPrathamesh Tagore // affine expression for this indice. 1206ed8434eSPrathamesh Tagore AffineExpr srcIndexExpr = linearize(ctx, dims, symbols); 1216ed8434eSPrathamesh Tagore 1226ed8434eSPrathamesh Tagore // Record the load index corresponding to each dimension in the 1236ed8434eSPrathamesh Tagore // reassociation group. These are later supplied as operands to the affine 1246ed8434eSPrathamesh Tagore // map used for calulating relevant index post op folding. 125829446cbSNicolas Vasilache SmallVector<OpFoldResult> dynamicIndices(groupSize); 126203fad47SNicolas Vasilache for (int64_t i = 0; i < groupSize; i++) 127203fad47SNicolas Vasilache dynamicIndices[i] = indices[groups[i]]; 128829446cbSNicolas Vasilache 1296ed8434eSPrathamesh Tagore // Supply suffix product results followed by load op indices as operands 1306ed8434eSPrathamesh Tagore // to the map. 
1316ed8434eSPrathamesh Tagore SmallVector<OpFoldResult> mapOperands; 1326ed8434eSPrathamesh Tagore llvm::append_range(mapOperands, suffixProduct); 1336ed8434eSPrathamesh Tagore llvm::append_range(mapOperands, dynamicIndices); 1346ed8434eSPrathamesh Tagore 1356ed8434eSPrathamesh Tagore // Creating maximally folded and composed affine.apply composes better 1366ed8434eSPrathamesh Tagore // with other transformations without interleaving canonicalization 1376ed8434eSPrathamesh Tagore // passes. 1384c48f016SMatthias Springer OpFoldResult ofr = affine::makeComposedFoldedAffineApply( 139829446cbSNicolas Vasilache rewriter, loc, 140829446cbSNicolas Vasilache AffineMap::get(/*numDims=*/groupSize, 1416ed8434eSPrathamesh Tagore /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr), 1426ed8434eSPrathamesh Tagore mapOperands); 1436ed8434eSPrathamesh Tagore 1446ed8434eSPrathamesh Tagore // Push index value in the op post folding corresponding to this 1456ed8434eSPrathamesh Tagore // reassociation group. 146829446cbSNicolas Vasilache sourceIndices.push_back( 147829446cbSNicolas Vasilache getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); 1481b002d27SArnab Dutta } 1491b002d27SArnab Dutta return success(); 1501b002d27SArnab Dutta } 1511b002d27SArnab Dutta 1521b002d27SArnab Dutta /// Given the 'indices' of a load/store operation where the memref is a result 1531b002d27SArnab Dutta /// of a collapse_shape op, returns the indices w.r.t to the source memref of 1541b002d27SArnab Dutta /// the collapse_shape op. For example 1551b002d27SArnab Dutta /// 1561b002d27SArnab Dutta /// %0 = ... 
: memref<2x6x42xf32> 1571b002d27SArnab Dutta /// %1 = memref.collapse_shape %0 [[0, 1], [2]] 1581b002d27SArnab Dutta /// : memref<2x6x42xf32> into memref<12x42xf32> 1591b002d27SArnab Dutta /// %2 = load %1[%i1, %i2] : memref<12x42xf32> 1601b002d27SArnab Dutta /// 1611b002d27SArnab Dutta /// could be folded into 1621b002d27SArnab Dutta /// 1631b002d27SArnab Dutta /// %2 = load %0[%i1 / 6, %i1 % 6, %i2] : 1641b002d27SArnab Dutta /// memref<2x6x42xf32> 1651b002d27SArnab Dutta static LogicalResult 1661b002d27SArnab Dutta resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, 1671b002d27SArnab Dutta memref::CollapseShapeOp collapseShapeOp, 1681b002d27SArnab Dutta ValueRange indices, 1691b002d27SArnab Dutta SmallVectorImpl<Value> &sourceIndices) { 170203fad47SNicolas Vasilache int64_t cnt = 0; 1711b002d27SArnab Dutta SmallVector<Value> tmp(indices.size()); 172829446cbSNicolas Vasilache SmallVector<OpFoldResult> dynamicIndices; 173203fad47SNicolas Vasilache for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) { 1741b002d27SArnab Dutta assert(!groups.empty() && "association indices groups cannot be empty"); 1751b002d27SArnab Dutta dynamicIndices.push_back(indices[cnt++]); 176203fad47SNicolas Vasilache int64_t groupSize = groups.size(); 177203fad47SNicolas Vasilache 178f32b3e1cSFelix Schneider // Calculate suffix product for all collapse op source dimension sizes 179f32b3e1cSFelix Schneider // except the most major one of each group. 180f32b3e1cSFelix Schneider // We allow the most major source dimension to be dynamic but enforce all 181f32b3e1cSFelix Schneider // others to be known statically. 
182f32b3e1cSFelix Schneider SmallVector<int64_t> sizes(groupSize, 1); 183f32b3e1cSFelix Schneider for (int64_t i = 1; i < groupSize; ++i) { 184203fad47SNicolas Vasilache sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]); 185f32b3e1cSFelix Schneider if (sizes[i] == ShapedType::kDynamic) 186f32b3e1cSFelix Schneider return failure(); 187f32b3e1cSFelix Schneider } 188203fad47SNicolas Vasilache SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes); 189203fad47SNicolas Vasilache 1901b002d27SArnab Dutta // Derive the index values along all dimensions of the source corresponding 1911b002d27SArnab Dutta // to the index wrt to collapsed shape op output. 192203fad47SNicolas Vasilache auto d0 = rewriter.getAffineDimExpr(0); 193203fad47SNicolas Vasilache SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct); 194203fad47SNicolas Vasilache 195203fad47SNicolas Vasilache // Construct the AffineApplyOp for each delinearizingExpr. 196829446cbSNicolas Vasilache for (int64_t i = 0; i < groupSize; i++) { 1974c48f016SMatthias Springer OpFoldResult ofr = affine::makeComposedFoldedAffineApply( 198829446cbSNicolas Vasilache rewriter, loc, 199203fad47SNicolas Vasilache AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, 200203fad47SNicolas Vasilache delinearizingExprs[i]), 201829446cbSNicolas Vasilache dynamicIndices); 202829446cbSNicolas Vasilache sourceIndices.push_back( 203829446cbSNicolas Vasilache getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); 204829446cbSNicolas Vasilache } 2051b002d27SArnab Dutta dynamicIndices.clear(); 2061b002d27SArnab Dutta } 2071b002d27SArnab Dutta if (collapseShapeOp.getReassociationIndices().empty()) { 2081b002d27SArnab Dutta auto zeroAffineMap = rewriter.getConstantAffineMap(0); 209203fad47SNicolas Vasilache int64_t srcRank = 2105550c821STres Popp cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank(); 211829446cbSNicolas Vasilache for (int64_t i = 0; i < srcRank; i++) { 2124c48f016SMatthias 
Springer OpFoldResult ofr = affine::makeComposedFoldedAffineApply( 213829446cbSNicolas Vasilache rewriter, loc, zeroAffineMap, dynamicIndices); 2141b002d27SArnab Dutta sourceIndices.push_back( 215829446cbSNicolas Vasilache getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); 216829446cbSNicolas Vasilache } 2171b002d27SArnab Dutta } 2181b002d27SArnab Dutta return success(); 2191b002d27SArnab Dutta } 2201b002d27SArnab Dutta 2211b002d27SArnab Dutta /// Helpers to access the memref operand for each op. 2221b002d27SArnab Dutta template <typename LoadOrStoreOpTy> 2231b002d27SArnab Dutta static Value getMemRefOperand(LoadOrStoreOpTy op) { 2241b002d27SArnab Dutta return op.getMemref(); 2251b002d27SArnab Dutta } 2261b002d27SArnab Dutta 2271b002d27SArnab Dutta static Value getMemRefOperand(vector::TransferReadOp op) { 2281b002d27SArnab Dutta return op.getSource(); 2291b002d27SArnab Dutta } 2301b002d27SArnab Dutta 23146c32afbSGuray Ozen static Value getMemRefOperand(nvgpu::LdMatrixOp op) { 23246c32afbSGuray Ozen return op.getSrcMemref(); 23346c32afbSGuray Ozen } 23446c32afbSGuray Ozen 2355ec360c5SGuray Ozen static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); } 2365ec360c5SGuray Ozen 237dae3c44cSMax191 static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); } 238dae3c44cSMax191 2395aa2c65aStyb0807 static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); } 2405aa2c65aStyb0807 241dae3c44cSMax191 static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); } 242dae3c44cSMax191 2431b002d27SArnab Dutta static Value getMemRefOperand(vector::TransferWriteOp op) { 2441b002d27SArnab Dutta return op.getSource(); 2451b002d27SArnab Dutta } 2461b002d27SArnab Dutta 24759e4fbfcSLei Zhang static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) { 24859e4fbfcSLei Zhang return op.getSrcMemref(); 24959e4fbfcSLei Zhang } 25059e4fbfcSLei Zhang 25159e4fbfcSLei Zhang static Value 
getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) { 25259e4fbfcSLei Zhang return op.getDstMemref(); 25359e4fbfcSLei Zhang } 25459e4fbfcSLei Zhang 2551b002d27SArnab Dutta //===----------------------------------------------------------------------===// 2561b002d27SArnab Dutta // Patterns 2571b002d27SArnab Dutta //===----------------------------------------------------------------------===// 2581b002d27SArnab Dutta 2591b002d27SArnab Dutta namespace { 2601b002d27SArnab Dutta /// Merges subview operation with load/transferRead operation. 2611b002d27SArnab Dutta template <typename OpTy> 2621b002d27SArnab Dutta class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> { 2631b002d27SArnab Dutta public: 2641b002d27SArnab Dutta using OpRewritePattern<OpTy>::OpRewritePattern; 2651b002d27SArnab Dutta 2661b002d27SArnab Dutta LogicalResult matchAndRewrite(OpTy loadOp, 2671b002d27SArnab Dutta PatternRewriter &rewriter) const override; 2681b002d27SArnab Dutta }; 2691b002d27SArnab Dutta 2701b002d27SArnab Dutta /// Merges expand_shape operation with load/transferRead operation. 2711b002d27SArnab Dutta template <typename OpTy> 2721b002d27SArnab Dutta class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> { 2731b002d27SArnab Dutta public: 2741b002d27SArnab Dutta using OpRewritePattern<OpTy>::OpRewritePattern; 2751b002d27SArnab Dutta 2761b002d27SArnab Dutta LogicalResult matchAndRewrite(OpTy loadOp, 2771b002d27SArnab Dutta PatternRewriter &rewriter) const override; 2781b002d27SArnab Dutta }; 2791b002d27SArnab Dutta 2801b002d27SArnab Dutta /// Merges collapse_shape operation with load/transferRead operation. 
2811b002d27SArnab Dutta template <typename OpTy> 2821b002d27SArnab Dutta class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> { 2831b002d27SArnab Dutta public: 2841b002d27SArnab Dutta using OpRewritePattern<OpTy>::OpRewritePattern; 2851b002d27SArnab Dutta 2861b002d27SArnab Dutta LogicalResult matchAndRewrite(OpTy loadOp, 2871b002d27SArnab Dutta PatternRewriter &rewriter) const override; 2881b002d27SArnab Dutta }; 2891b002d27SArnab Dutta 2901b002d27SArnab Dutta /// Merges subview operation with store/transferWriteOp operation. 2911b002d27SArnab Dutta template <typename OpTy> 2921b002d27SArnab Dutta class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> { 2931b002d27SArnab Dutta public: 2941b002d27SArnab Dutta using OpRewritePattern<OpTy>::OpRewritePattern; 2951b002d27SArnab Dutta 2961b002d27SArnab Dutta LogicalResult matchAndRewrite(OpTy storeOp, 2971b002d27SArnab Dutta PatternRewriter &rewriter) const override; 2981b002d27SArnab Dutta }; 2991b002d27SArnab Dutta 3001b002d27SArnab Dutta /// Merges expand_shape operation with store/transferWriteOp operation. 3011b002d27SArnab Dutta template <typename OpTy> 3021b002d27SArnab Dutta class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> { 3031b002d27SArnab Dutta public: 3041b002d27SArnab Dutta using OpRewritePattern<OpTy>::OpRewritePattern; 3051b002d27SArnab Dutta 3061b002d27SArnab Dutta LogicalResult matchAndRewrite(OpTy storeOp, 3071b002d27SArnab Dutta PatternRewriter &rewriter) const override; 3081b002d27SArnab Dutta }; 3091b002d27SArnab Dutta 3101b002d27SArnab Dutta /// Merges collapse_shape operation with store/transferWriteOp operation. 
3111b002d27SArnab Dutta template <typename OpTy> 3121b002d27SArnab Dutta class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> { 3131b002d27SArnab Dutta public: 3141b002d27SArnab Dutta using OpRewritePattern<OpTy>::OpRewritePattern; 3151b002d27SArnab Dutta 3161b002d27SArnab Dutta LogicalResult matchAndRewrite(OpTy storeOp, 3171b002d27SArnab Dutta PatternRewriter &rewriter) const override; 3181b002d27SArnab Dutta }; 3191b002d27SArnab Dutta 320ccb8a4e3SMatthias Springer /// Folds subview(subview(x)) to a single subview(x). 321ccb8a4e3SMatthias Springer class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> { 322ccb8a4e3SMatthias Springer public: 323ccb8a4e3SMatthias Springer using OpRewritePattern<memref::SubViewOp>::OpRewritePattern; 324ccb8a4e3SMatthias Springer 325ccb8a4e3SMatthias Springer LogicalResult matchAndRewrite(memref::SubViewOp subView, 326ccb8a4e3SMatthias Springer PatternRewriter &rewriter) const override { 327ccb8a4e3SMatthias Springer auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>(); 328ccb8a4e3SMatthias Springer if (!srcSubView) 329ccb8a4e3SMatthias Springer return failure(); 330ccb8a4e3SMatthias Springer 33133468a51SNicolas Vasilache // TODO: relax unit stride assumption. 33233468a51SNicolas Vasilache if (!subView.hasUnitStride()) { 33333468a51SNicolas Vasilache return rewriter.notifyMatchFailure(subView, "requires unit strides"); 334ccb8a4e3SMatthias Springer } 33533468a51SNicolas Vasilache if (!srcSubView.hasUnitStride()) { 33633468a51SNicolas Vasilache return rewriter.notifyMatchFailure(srcSubView, "requires unit strides"); 337ccb8a4e3SMatthias Springer } 338ccb8a4e3SMatthias Springer 33933468a51SNicolas Vasilache // Resolve sizes according to dropped dims. 
34033468a51SNicolas Vasilache SmallVector<OpFoldResult> resolvedSizes; 34133468a51SNicolas Vasilache llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims(); 3424c48f016SMatthias Springer affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(), 34333468a51SNicolas Vasilache subView.getMixedSizes(), srcDroppedDims, 34433468a51SNicolas Vasilache resolvedSizes); 34533468a51SNicolas Vasilache 34633468a51SNicolas Vasilache // Resolve offsets according to source offsets and strides. 34733468a51SNicolas Vasilache SmallVector<Value> resolvedOffsets; 3484c48f016SMatthias Springer affine::resolveIndicesIntoOpWithOffsetsAndStrides( 34933468a51SNicolas Vasilache rewriter, subView.getLoc(), srcSubView.getMixedOffsets(), 35033468a51SNicolas Vasilache srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(), 35133468a51SNicolas Vasilache resolvedOffsets); 35233468a51SNicolas Vasilache 353ccb8a4e3SMatthias Springer // Replace original op. 354ccb8a4e3SMatthias Springer rewriter.replaceOpWithNewOp<memref::SubViewOp>( 35533468a51SNicolas Vasilache subView, subView.getType(), srcSubView.getSource(), 35633468a51SNicolas Vasilache getAsOpFoldResult(resolvedOffsets), resolvedSizes, 35733468a51SNicolas Vasilache srcSubView.getMixedStrides()); 35833468a51SNicolas Vasilache 359ccb8a4e3SMatthias Springer return success(); 360ccb8a4e3SMatthias Springer } 361ccb8a4e3SMatthias Springer }; 362fc5c1a76SManish Gupta 363fc5c1a76SManish Gupta /// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern 364fc5c1a76SManish Gupta /// is folds subview on src and dst memref of the copy. 
365baa5beecStyb0807 class NVGPUAsyncCopyOpSubViewOpFolder final 366fc5c1a76SManish Gupta : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> { 367fc5c1a76SManish Gupta public: 368fc5c1a76SManish Gupta using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern; 369fc5c1a76SManish Gupta 370fc5c1a76SManish Gupta LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp, 371fc5c1a76SManish Gupta PatternRewriter &rewriter) const override; 372fc5c1a76SManish Gupta }; 3731b002d27SArnab Dutta } // namespace 3741b002d27SArnab Dutta 3751b002d27SArnab Dutta static SmallVector<Value> 3762fe37d1cSMehdi Amini calculateExpandedAccessIndices(AffineMap affineMap, 3772fe37d1cSMehdi Amini const SmallVector<Value> &indices, Location loc, 3782fe37d1cSMehdi Amini PatternRewriter &rewriter) { 379829446cbSNicolas Vasilache SmallVector<OpFoldResult> indicesOfr(llvm::to_vector( 380829446cbSNicolas Vasilache llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }))); 3811b002d27SArnab Dutta SmallVector<Value> expandedIndices; 382829446cbSNicolas Vasilache for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) { 3834c48f016SMatthias Springer OpFoldResult ofr = affine::makeComposedFoldedAffineApply( 384829446cbSNicolas Vasilache rewriter, loc, affineMap.getSubMap({i}), indicesOfr); 3851b002d27SArnab Dutta expandedIndices.push_back( 386829446cbSNicolas Vasilache getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); 387829446cbSNicolas Vasilache } 3881b002d27SArnab Dutta return expandedIndices; 3891b002d27SArnab Dutta } 3901b002d27SArnab Dutta 3914dc72d47SNicolas Vasilache template <typename XferOp> 3924dc72d47SNicolas Vasilache static LogicalResult 3934dc72d47SNicolas Vasilache preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp, 3944dc72d47SNicolas Vasilache memref::SubViewOp subviewOp) { 3954dc72d47SNicolas Vasilache static_assert( 3964dc72d47SNicolas Vasilache !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value, 
3974dc72d47SNicolas Vasilache "must be a vector transfer op"); 3984dc72d47SNicolas Vasilache if (xferOp.hasOutOfBoundsDim()) 3994dc72d47SNicolas Vasilache return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim"); 4004dc72d47SNicolas Vasilache if (!subviewOp.hasUnitStride()) { 4014dc72d47SNicolas Vasilache return rewriter.notifyMatchFailure( 4024dc72d47SNicolas Vasilache xferOp, "non-1 stride subview, need to track strides in folded memref"); 4034dc72d47SNicolas Vasilache } 4044dc72d47SNicolas Vasilache return success(); 4054dc72d47SNicolas Vasilache } 4064dc72d47SNicolas Vasilache 4074dc72d47SNicolas Vasilache static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, 4084dc72d47SNicolas Vasilache Operation *op, 4094dc72d47SNicolas Vasilache memref::SubViewOp subviewOp) { 4104dc72d47SNicolas Vasilache return success(); 4114dc72d47SNicolas Vasilache } 4124dc72d47SNicolas Vasilache 4134dc72d47SNicolas Vasilache static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, 4144dc72d47SNicolas Vasilache vector::TransferReadOp readOp, 4154dc72d47SNicolas Vasilache memref::SubViewOp subviewOp) { 4164dc72d47SNicolas Vasilache return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp); 4174dc72d47SNicolas Vasilache } 4184dc72d47SNicolas Vasilache 4194dc72d47SNicolas Vasilache static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, 4204dc72d47SNicolas Vasilache vector::TransferWriteOp writeOp, 4214dc72d47SNicolas Vasilache memref::SubViewOp subviewOp) { 4224dc72d47SNicolas Vasilache return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp); 4234dc72d47SNicolas Vasilache } 4244dc72d47SNicolas Vasilache 4251b002d27SArnab Dutta template <typename OpTy> 4261b002d27SArnab Dutta LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite( 4271b002d27SArnab Dutta OpTy loadOp, PatternRewriter &rewriter) const { 4281b002d27SArnab Dutta auto subViewOp = 4291b002d27SArnab Dutta getMemRefOperand(loadOp).template 
getDefiningOp<memref::SubViewOp>(); 4301b002d27SArnab Dutta 4311b002d27SArnab Dutta if (!subViewOp) 4324dc72d47SNicolas Vasilache return rewriter.notifyMatchFailure(loadOp, "not a subview producer"); 4334dc72d47SNicolas Vasilache 4344dc72d47SNicolas Vasilache LogicalResult preconditionResult = 4354dc72d47SNicolas Vasilache preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp); 4364dc72d47SNicolas Vasilache if (failed(preconditionResult)) 4374dc72d47SNicolas Vasilache return preconditionResult; 4381b002d27SArnab Dutta 4391b002d27SArnab Dutta SmallVector<Value> indices(loadOp.getIndices().begin(), 4401b002d27SArnab Dutta loadOp.getIndices().end()); 4411b002d27SArnab Dutta // For affine ops, we need to apply the map to get the operands to get the 4421b002d27SArnab Dutta // "actual" indices. 4434c48f016SMatthias Springer if (auto affineLoadOp = 4444c48f016SMatthias Springer dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) { 4451b002d27SArnab Dutta AffineMap affineMap = affineLoadOp.getAffineMap(); 4461b002d27SArnab Dutta auto expandedIndices = calculateExpandedAccessIndices( 4471b002d27SArnab Dutta affineMap, indices, loadOp.getLoc(), rewriter); 4481b002d27SArnab Dutta indices.assign(expandedIndices.begin(), expandedIndices.end()); 4491b002d27SArnab Dutta } 450203fad47SNicolas Vasilache SmallVector<Value> sourceIndices; 4514c48f016SMatthias Springer affine::resolveIndicesIntoOpWithOffsetsAndStrides( 4524dc72d47SNicolas Vasilache rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(), 4534dc72d47SNicolas Vasilache subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices, 4544dc72d47SNicolas Vasilache sourceIndices); 455b7d47ed1SNicolas Vasilache 4561b002d27SArnab Dutta llvm::TypeSwitch<Operation *, void>(loadOp) 4574c48f016SMatthias Springer .Case([&](affine::AffineLoadOp op) { 4584c48f016SMatthias Springer rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( 4594c48f016SMatthias Springer loadOp, subViewOp.getSource(), sourceIndices); 4601b002d27SArnab 
      })
      .Case([&](memref::LoadOp op) {
        // Plain memref load: same indices rewrite, preserving the
        // nontemporal hint.
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        // The permutation map is expressed against the subview's rank;
        // expand it to the source rank, re-inserting the dims the subview
        // dropped, so it stays valid on the original memref.
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      // Patterns are only registered for the op types handled above, so any
      // other operation reaching this switch is a programming error.
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

/// Folds a load whose memref operand is produced by a memref.expand_shape:
/// the load is rewritten against the expand_shape's source with the access
/// indices linearized back into the source's (collapsed) index space.
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to get the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  // Dispatch on the concrete load op type so each replacement carries over
  // the op-specific attributes/operands (nontemporal, mask, pass-thru, ...).
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

/// Folds a load whose memref operand is produced by a memref.collapse_shape:
/// the load is rewritten against the collapse_shape's source with the access
/// indices delinearized into the source's (expanded) index space.
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to get the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

/// Folds a store whose memref operand is produced by a memref.subview: the
/// store is rewritten against the subview's source memref, with the indices
/// offset/strided through the subview's offsets and strides.
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to get the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        // As for transfer_read: expand the permutation map back to the
        // source rank, re-inserting the subview's dropped dims.
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        // Note the MaskedStoreOp builder operand order:
        // (base, indices, mask, valueToStore).
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

/// Folds a store whose memref operand is produced by a memref.expand_shape:
/// the store is rewritten against the expand_shape's source with the access
/// indices linearized back into the source's (collapsed) index space.
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to get the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

/// Folds a store whose memref operand is produced by a memref.collapse_shape:
/// the store is rewritten against the collapse_shape's source with the access
/// indices delinearized into the source's (expanded) index space.
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to get the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

/// Folds memref.subview producers of either (or both) the source and the
/// destination of an nvgpu.device_async_copy, resolving the copy's indices
/// into the subviews' original memrefs.
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  // Folded indices default to the original ones when there is no subview.
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy op that uses the source and destination
  // of the subview.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

/// Registers all the alias-folding patterns defined in this file on the
/// given pattern set.
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

/// Pass that greedily applies all the memref alias-folding patterns above.
struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  // Best-effort application: convergence failure is not an error for this
  // pass, hence the discarded result.
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
  return std::make_unique<FoldMemRefAliasOpsPass>();
}