//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold tensor subset ops with producers / consumers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>

Vasilache namespace mlir { 304dc72d47SNicolas Vasilache namespace tensor { 314dc72d47SNicolas Vasilache #define GEN_PASS_DEF_FOLDTENSORSUBSETOPS 324dc72d47SNicolas Vasilache #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc" 334dc72d47SNicolas Vasilache } // namespace tensor 344dc72d47SNicolas Vasilache } // namespace mlir 354dc72d47SNicolas Vasilache 364dc72d47SNicolas Vasilache using namespace mlir; 374dc72d47SNicolas Vasilache 384dc72d47SNicolas Vasilache static Value getTensorOperand(vector::TransferReadOp op) { 394dc72d47SNicolas Vasilache return op.getSource(); 404dc72d47SNicolas Vasilache } 414dc72d47SNicolas Vasilache 424dc72d47SNicolas Vasilache static Value getTensorOperand(tensor::InsertSliceOp op) { 434dc72d47SNicolas Vasilache return op.getSource(); 444dc72d47SNicolas Vasilache } 454dc72d47SNicolas Vasilache 464dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===// 474dc72d47SNicolas Vasilache // Patterns 484dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===// 494dc72d47SNicolas Vasilache 504dc72d47SNicolas Vasilache namespace { 514dc72d47SNicolas Vasilache /// Merge extract_slice operation with load/transferRead operation. 
524dc72d47SNicolas Vasilache class TransferReadOfExtractSliceOpFolder final 531ede5039SHugo Trachino : public vector::MaskableOpRewritePattern<vector::TransferReadOp> { 544dc72d47SNicolas Vasilache public: 551ede5039SHugo Trachino using MaskableOpRewritePattern::MaskableOpRewritePattern; 564dc72d47SNicolas Vasilache 571ede5039SHugo Trachino FailureOr<mlir::Value> 581ede5039SHugo Trachino matchAndRewriteMaskableOp(vector::TransferReadOp readOp, 591ede5039SHugo Trachino vector::MaskingOpInterface maskOp, 604dc72d47SNicolas Vasilache PatternRewriter &rewriter) const override; 614dc72d47SNicolas Vasilache }; 624dc72d47SNicolas Vasilache 634dc72d47SNicolas Vasilache /// Merge insert_slice operation with store/transferWriteOp operation. 644dc72d47SNicolas Vasilache class InsertSliceOfTransferWriteOpFolder final 654dc72d47SNicolas Vasilache : public OpRewritePattern<tensor::InsertSliceOp> { 664dc72d47SNicolas Vasilache public: 674dc72d47SNicolas Vasilache using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern; 684dc72d47SNicolas Vasilache 694dc72d47SNicolas Vasilache LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp, 704dc72d47SNicolas Vasilache PatternRewriter &rewriter) const override; 71760ffa47SRajveer Singh Bharadwaj 72760ffa47SRajveer Singh Bharadwaj private: 73760ffa47SRajveer Singh Bharadwaj static bool 74760ffa47SRajveer Singh Bharadwaj doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp); 754dc72d47SNicolas Vasilache }; 764dc72d47SNicolas Vasilache } // namespace 774dc72d47SNicolas Vasilache 784dc72d47SNicolas Vasilache template <typename XferOp, typename ExtractOrInsertOp> 794dc72d47SNicolas Vasilache static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp( 804dc72d47SNicolas Vasilache RewriterBase &rewriter, XferOp xferOp, 814dc72d47SNicolas Vasilache ExtractOrInsertOp extractOrInsertSliceOp) { 824dc72d47SNicolas Vasilache if (xferOp.hasOutOfBoundsDim()) 834dc72d47SNicolas Vasilache return 
rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim"); 844dc72d47SNicolas Vasilache if (xferOp.getMask()) 854dc72d47SNicolas Vasilache return rewriter.notifyMatchFailure(xferOp, "masked transfer"); 864dc72d47SNicolas Vasilache if (!extractOrInsertSliceOp.hasUnitStride()) { 874dc72d47SNicolas Vasilache return rewriter.notifyMatchFailure( 884dc72d47SNicolas Vasilache xferOp, "non-1 stride insert/extract, requires keeping track of " 894dc72d47SNicolas Vasilache "strides, this may result in needing to insert " 904dc72d47SNicolas Vasilache "vector.insert_strided_slice/extract_strided_slice ops"); 914dc72d47SNicolas Vasilache } 924dc72d47SNicolas Vasilache return success(); 934dc72d47SNicolas Vasilache } 944dc72d47SNicolas Vasilache 951ede5039SHugo Trachino FailureOr<mlir::Value> 961ede5039SHugo Trachino TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp( 971ede5039SHugo Trachino vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp, 981ede5039SHugo Trachino PatternRewriter &rewriter) const { 994dc72d47SNicolas Vasilache auto extractSliceOp = 1004dc72d47SNicolas Vasilache getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>(); 1014dc72d47SNicolas Vasilache if (!extractSliceOp) 1024dc72d47SNicolas Vasilache return rewriter.notifyMatchFailure(readOp, "not an extract_slice"); 1034dc72d47SNicolas Vasilache 1044dc72d47SNicolas Vasilache LogicalResult preconditionResult = 1054dc72d47SNicolas Vasilache preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp, 1064dc72d47SNicolas Vasilache extractSliceOp); 1074dc72d47SNicolas Vasilache if (failed(preconditionResult)) 1081ede5039SHugo Trachino return rewriter.notifyMatchFailure(readOp, "Failed preconditions"); 1094dc72d47SNicolas Vasilache 1104dc72d47SNicolas Vasilache SmallVector<Value> indices(readOp.getIndices().begin(), 1114dc72d47SNicolas Vasilache readOp.getIndices().end()); 1124dc72d47SNicolas Vasilache SmallVector<Value> sourceIndices; 1134c48f016SMatthias Springer 
affine::resolveIndicesIntoOpWithOffsetsAndStrides( 1144dc72d47SNicolas Vasilache rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(), 1154dc72d47SNicolas Vasilache extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(), 1164dc72d47SNicolas Vasilache indices, sourceIndices); 1174dc72d47SNicolas Vasilache 1181ede5039SHugo Trachino Operation *newOp = rewriter.create<vector::TransferReadOp>( 1191ede5039SHugo Trachino readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(), 1201ede5039SHugo Trachino sourceIndices, 1214dc72d47SNicolas Vasilache AffineMapAttr::get(expandDimsToRank( 1224dc72d47SNicolas Vasilache readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(), 1234dc72d47SNicolas Vasilache extractSliceOp.getDroppedDims())), 1244dc72d47SNicolas Vasilache readOp.getPadding(), 1254dc72d47SNicolas Vasilache /*mask=*/Value(), readOp.getInBoundsAttr()); 1261ede5039SHugo Trachino if (maskOp) 1271ede5039SHugo Trachino newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask()); 1281ede5039SHugo Trachino return newOp->getResults()[0]; 1294dc72d47SNicolas Vasilache } 1304dc72d47SNicolas Vasilache 1314dc72d47SNicolas Vasilache LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite( 1324dc72d47SNicolas Vasilache tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const { 1334dc72d47SNicolas Vasilache auto writeOp = getTensorOperand(insertSliceOp) 1344dc72d47SNicolas Vasilache .template getDefiningOp<vector::TransferWriteOp>(); 1354dc72d47SNicolas Vasilache if (!writeOp) 1364dc72d47SNicolas Vasilache return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write"); 1374dc72d47SNicolas Vasilache 1384dc72d47SNicolas Vasilache LogicalResult preconditionResult = 1394dc72d47SNicolas Vasilache preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp, 1404dc72d47SNicolas Vasilache insertSliceOp); 1414dc72d47SNicolas Vasilache if (failed(preconditionResult)) 1424dc72d47SNicolas Vasilache 
return preconditionResult; 1434dc72d47SNicolas Vasilache 144760ffa47SRajveer Singh Bharadwaj if (!doesTransferWriteCoverInsertSlice(writeOp)) 145760ffa47SRajveer Singh Bharadwaj return rewriter.notifyMatchFailure( 146760ffa47SRajveer Singh Bharadwaj insertSliceOp, "transfer_write does not cover insert_slice"); 147760ffa47SRajveer Singh Bharadwaj 1484dc72d47SNicolas Vasilache SmallVector<Value> indices(writeOp.getIndices().begin(), 1494dc72d47SNicolas Vasilache writeOp.getIndices().end()); 1504dc72d47SNicolas Vasilache SmallVector<Value> sourceIndices; 1514c48f016SMatthias Springer affine::resolveIndicesIntoOpWithOffsetsAndStrides( 1524dc72d47SNicolas Vasilache rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(), 1534dc72d47SNicolas Vasilache insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices, 1544dc72d47SNicolas Vasilache sourceIndices); 1554dc72d47SNicolas Vasilache 1564dc72d47SNicolas Vasilache rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 1574dc72d47SNicolas Vasilache insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices, 1584dc72d47SNicolas Vasilache AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(), 1594dc72d47SNicolas Vasilache insertSliceOp.getDestType().getRank(), 1604dc72d47SNicolas Vasilache insertSliceOp.getDroppedDims())), 1614dc72d47SNicolas Vasilache writeOp.getInBoundsAttr()); 1624dc72d47SNicolas Vasilache 1634dc72d47SNicolas Vasilache return success(); 1644dc72d47SNicolas Vasilache } 1654dc72d47SNicolas Vasilache 166760ffa47SRajveer Singh Bharadwaj bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice( 167760ffa47SRajveer Singh Bharadwaj vector::TransferWriteOp writeOp) { 168760ffa47SRajveer Singh Bharadwaj if (writeOp.getShapedType().hasStaticShape()) 169760ffa47SRajveer Singh Bharadwaj return llvm::equal(writeOp.getVectorType().getShape(), 170760ffa47SRajveer Singh Bharadwaj writeOp.getShapedType().getShape()); 171760ffa47SRajveer Singh Bharadwaj 
172760ffa47SRajveer Singh Bharadwaj // TODO: Use ValueBoundsConstraintSet for dynamic shapes. 173760ffa47SRajveer Singh Bharadwaj 174760ffa47SRajveer Singh Bharadwaj return false; 175760ffa47SRajveer Singh Bharadwaj } 176760ffa47SRajveer Singh Bharadwaj 17733468a51SNicolas Vasilache template <typename OpTy> 17833468a51SNicolas Vasilache struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> { 17933468a51SNicolas Vasilache using OpRewritePattern<OpTy>::OpRewritePattern; 18033468a51SNicolas Vasilache 18133468a51SNicolas Vasilache LogicalResult matchAndRewrite(OpTy insertSliceOp, 18233468a51SNicolas Vasilache PatternRewriter &rewriter) const override { 18333468a51SNicolas Vasilache auto sourceInsertSliceOp = 18433468a51SNicolas Vasilache insertSliceOp.getSource() 18533468a51SNicolas Vasilache .template getDefiningOp<tensor::InsertSliceOp>(); 18633468a51SNicolas Vasilache if (!sourceInsertSliceOp) 18733468a51SNicolas Vasilache return failure(); 18833468a51SNicolas Vasilache 18933468a51SNicolas Vasilache // TODO: relax unit stride assumption where possible. 
19033468a51SNicolas Vasilache if (!insertSliceOp.hasUnitStride()) { 19133468a51SNicolas Vasilache return rewriter.notifyMatchFailure(insertSliceOp, 19233468a51SNicolas Vasilache "requires unit strides"); 19333468a51SNicolas Vasilache } 19433468a51SNicolas Vasilache if (!sourceInsertSliceOp.hasUnitStride()) { 19533468a51SNicolas Vasilache return rewriter.notifyMatchFailure(sourceInsertSliceOp, 19633468a51SNicolas Vasilache "requires unit strides"); 19733468a51SNicolas Vasilache } 19833468a51SNicolas Vasilache 19933468a51SNicolas Vasilache int64_t srcDim = 0; 20033468a51SNicolas Vasilache llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims(); 20133468a51SNicolas Vasilache for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) { 20233468a51SNicolas Vasilache if (droppedDims[d]) 20333468a51SNicolas Vasilache continue; 20433468a51SNicolas Vasilache if (insertSliceOp.getMixedSizes()[d] != 20533468a51SNicolas Vasilache sourceInsertSliceOp.getMixedSizes()[srcDim++]) { 20633468a51SNicolas Vasilache return rewriter.notifyMatchFailure( 20733468a51SNicolas Vasilache sourceInsertSliceOp, 20833468a51SNicolas Vasilache "requires matching sizes to fold, otherwise a copy is needed"); 20933468a51SNicolas Vasilache } 21033468a51SNicolas Vasilache } 21133468a51SNicolas Vasilache 21233468a51SNicolas Vasilache // Resolve sizes according to dropped dims. 21333468a51SNicolas Vasilache SmallVector<OpFoldResult> resolvedSizes; 21433468a51SNicolas Vasilache // Note: the "insertSlice" case is symmetrical to the extract/subview case: 21533468a51SNicolas Vasilache // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is 21633468a51SNicolas Vasilache // passed as the destination to the helper function. 
2174c48f016SMatthias Springer affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(), 21833468a51SNicolas Vasilache sourceInsertSliceOp.getMixedSizes(), 21933468a51SNicolas Vasilache droppedDims, resolvedSizes); 22033468a51SNicolas Vasilache 22133468a51SNicolas Vasilache // If we are inside an InParallel region, temporarily set the insertion 22233468a51SNicolas Vasilache // point outside: only tensor.parallel_insert_slice ops are allowed in 22333468a51SNicolas Vasilache // there. 22433468a51SNicolas Vasilache if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) { 22533468a51SNicolas Vasilache rewriter.setInsertionPoint( 22633468a51SNicolas Vasilache insertSliceOp->template getParentOfType<scf::InParallelOp>()); 22733468a51SNicolas Vasilache } 22833468a51SNicolas Vasilache 22933468a51SNicolas Vasilache // Resolve offsets according to source offsets and strides. 23033468a51SNicolas Vasilache SmallVector<Value> resolvedOffsets; 23133468a51SNicolas Vasilache // Note: the "insertSlice" case is symmetrical to the extract/subview case: 23233468a51SNicolas Vasilache // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is 23333468a51SNicolas Vasilache // passed as the destination to the helper function. 2344c48f016SMatthias Springer affine::resolveIndicesIntoOpWithOffsetsAndStrides( 23533468a51SNicolas Vasilache rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(), 23633468a51SNicolas Vasilache insertSliceOp.getMixedStrides(), droppedDims, 23733468a51SNicolas Vasilache sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets); 23833468a51SNicolas Vasilache 23933468a51SNicolas Vasilache // Reset the insertion point. 24033468a51SNicolas Vasilache rewriter.setInsertionPoint(insertSliceOp); 24133468a51SNicolas Vasilache // Replace original op. 
24233468a51SNicolas Vasilache rewriter.replaceOpWithNewOp<OpTy>( 24333468a51SNicolas Vasilache insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(), 24433468a51SNicolas Vasilache getAsOpFoldResult(resolvedOffsets), resolvedSizes, 24533468a51SNicolas Vasilache insertSliceOp.getMixedStrides()); 24633468a51SNicolas Vasilache 24733468a51SNicolas Vasilache return success(); 24833468a51SNicolas Vasilache } 24933468a51SNicolas Vasilache }; 25033468a51SNicolas Vasilache 2514dc72d47SNicolas Vasilache void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) { 252867afe5eSMatthias Springer populateFoldTensorSubsetIntoVectorTransferPatterns(patterns); 253867afe5eSMatthias Springer patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>, 25433468a51SNicolas Vasilache InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>( 25533468a51SNicolas Vasilache patterns.getContext()); 2564dc72d47SNicolas Vasilache } 257867afe5eSMatthias Springer 258867afe5eSMatthias Springer void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns( 259867afe5eSMatthias Springer RewritePatternSet &patterns) { 260867afe5eSMatthias Springer patterns.add<TransferReadOfExtractSliceOpFolder, 261867afe5eSMatthias Springer InsertSliceOfTransferWriteOpFolder>(patterns.getContext()); 262867afe5eSMatthias Springer } 263867afe5eSMatthias Springer 2644dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===// 2654dc72d47SNicolas Vasilache // Pass registration 2664dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===// 2674dc72d47SNicolas Vasilache 2684dc72d47SNicolas Vasilache namespace { 2694dc72d47SNicolas Vasilache 2704dc72d47SNicolas Vasilache struct FoldTensorSubsetOpsPass final 2714dc72d47SNicolas Vasilache : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> { 2724dc72d47SNicolas Vasilache void runOnOperation() override; 
2734dc72d47SNicolas Vasilache }; 2744dc72d47SNicolas Vasilache 2754dc72d47SNicolas Vasilache } // namespace 2764dc72d47SNicolas Vasilache 2774dc72d47SNicolas Vasilache void FoldTensorSubsetOpsPass::runOnOperation() { 2784dc72d47SNicolas Vasilache RewritePatternSet patterns(&getContext()); 2794dc72d47SNicolas Vasilache tensor::populateFoldTensorSubsetOpPatterns(patterns); 280*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 2814dc72d47SNicolas Vasilache } 2824dc72d47SNicolas Vasilache 2834dc72d47SNicolas Vasilache std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() { 2844dc72d47SNicolas Vasilache return std::make_unique<FoldTensorSubsetOpsPass>(); 2854dc72d47SNicolas Vasilache } 286