xref: /llvm-project/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
14dc72d47SNicolas Vasilache //===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
24dc72d47SNicolas Vasilache //
34dc72d47SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44dc72d47SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
54dc72d47SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64dc72d47SNicolas Vasilache //
74dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===//
84dc72d47SNicolas Vasilache //
94dc72d47SNicolas Vasilache // Fold tensor subset ops with producer / consumers.
104dc72d47SNicolas Vasilache //
114dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===//
124dc72d47SNicolas Vasilache 
134dc72d47SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
144dc72d47SNicolas Vasilache #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
1533468a51SNicolas Vasilache #include "mlir/Dialect/SCF/IR/SCF.h"
164dc72d47SNicolas Vasilache #include "mlir/Dialect/Tensor/IR/Tensor.h"
174dc72d47SNicolas Vasilache #include "mlir/Dialect/Tensor/Transforms/Passes.h"
184dc72d47SNicolas Vasilache #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
194dc72d47SNicolas Vasilache #include "mlir/Dialect/Utils/IndexingUtils.h"
204dc72d47SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h"
211ede5039SHugo Trachino #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
224dc72d47SNicolas Vasilache #include "mlir/IR/AffineMap.h"
234dc72d47SNicolas Vasilache #include "mlir/IR/BuiltinAttributes.h"
24760ffa47SRajveer Singh Bharadwaj #include "mlir/Interfaces/ValueBoundsOpInterface.h"
254dc72d47SNicolas Vasilache #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
264dc72d47SNicolas Vasilache #include "llvm/ADT/TypeSwitch.h"
2733468a51SNicolas Vasilache #include <type_traits>
284dc72d47SNicolas Vasilache 
294dc72d47SNicolas Vasilache namespace mlir {
304dc72d47SNicolas Vasilache namespace tensor {
314dc72d47SNicolas Vasilache #define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
324dc72d47SNicolas Vasilache #include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
334dc72d47SNicolas Vasilache } // namespace tensor
344dc72d47SNicolas Vasilache } // namespace mlir
354dc72d47SNicolas Vasilache 
364dc72d47SNicolas Vasilache using namespace mlir;
374dc72d47SNicolas Vasilache 
384dc72d47SNicolas Vasilache static Value getTensorOperand(vector::TransferReadOp op) {
394dc72d47SNicolas Vasilache   return op.getSource();
404dc72d47SNicolas Vasilache }
414dc72d47SNicolas Vasilache 
424dc72d47SNicolas Vasilache static Value getTensorOperand(tensor::InsertSliceOp op) {
434dc72d47SNicolas Vasilache   return op.getSource();
444dc72d47SNicolas Vasilache }
454dc72d47SNicolas Vasilache 
464dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===//
474dc72d47SNicolas Vasilache // Patterns
484dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===//
494dc72d47SNicolas Vasilache 
504dc72d47SNicolas Vasilache namespace {
514dc72d47SNicolas Vasilache /// Merge extract_slice operation with load/transferRead operation.
524dc72d47SNicolas Vasilache class TransferReadOfExtractSliceOpFolder final
531ede5039SHugo Trachino     : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
544dc72d47SNicolas Vasilache public:
551ede5039SHugo Trachino   using MaskableOpRewritePattern::MaskableOpRewritePattern;
564dc72d47SNicolas Vasilache 
571ede5039SHugo Trachino   FailureOr<mlir::Value>
581ede5039SHugo Trachino   matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
591ede5039SHugo Trachino                             vector::MaskingOpInterface maskOp,
604dc72d47SNicolas Vasilache                             PatternRewriter &rewriter) const override;
614dc72d47SNicolas Vasilache };
624dc72d47SNicolas Vasilache 
634dc72d47SNicolas Vasilache /// Merge insert_slice operation with store/transferWriteOp operation.
644dc72d47SNicolas Vasilache class InsertSliceOfTransferWriteOpFolder final
654dc72d47SNicolas Vasilache     : public OpRewritePattern<tensor::InsertSliceOp> {
664dc72d47SNicolas Vasilache public:
674dc72d47SNicolas Vasilache   using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
684dc72d47SNicolas Vasilache 
694dc72d47SNicolas Vasilache   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
704dc72d47SNicolas Vasilache                                 PatternRewriter &rewriter) const override;
71760ffa47SRajveer Singh Bharadwaj 
72760ffa47SRajveer Singh Bharadwaj private:
73760ffa47SRajveer Singh Bharadwaj   static bool
74760ffa47SRajveer Singh Bharadwaj   doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
754dc72d47SNicolas Vasilache };
764dc72d47SNicolas Vasilache } // namespace
774dc72d47SNicolas Vasilache 
784dc72d47SNicolas Vasilache template <typename XferOp, typename ExtractOrInsertOp>
794dc72d47SNicolas Vasilache static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
804dc72d47SNicolas Vasilache     RewriterBase &rewriter, XferOp xferOp,
814dc72d47SNicolas Vasilache     ExtractOrInsertOp extractOrInsertSliceOp) {
824dc72d47SNicolas Vasilache   if (xferOp.hasOutOfBoundsDim())
834dc72d47SNicolas Vasilache     return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
844dc72d47SNicolas Vasilache   if (xferOp.getMask())
854dc72d47SNicolas Vasilache     return rewriter.notifyMatchFailure(xferOp, "masked transfer");
864dc72d47SNicolas Vasilache   if (!extractOrInsertSliceOp.hasUnitStride()) {
874dc72d47SNicolas Vasilache     return rewriter.notifyMatchFailure(
884dc72d47SNicolas Vasilache         xferOp, "non-1 stride insert/extract, requires keeping track of "
894dc72d47SNicolas Vasilache                 "strides, this may result in needing to insert "
904dc72d47SNicolas Vasilache                 "vector.insert_strided_slice/extract_strided_slice ops");
914dc72d47SNicolas Vasilache   }
924dc72d47SNicolas Vasilache   return success();
934dc72d47SNicolas Vasilache }
944dc72d47SNicolas Vasilache 
951ede5039SHugo Trachino FailureOr<mlir::Value>
961ede5039SHugo Trachino TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
971ede5039SHugo Trachino     vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
981ede5039SHugo Trachino     PatternRewriter &rewriter) const {
994dc72d47SNicolas Vasilache   auto extractSliceOp =
1004dc72d47SNicolas Vasilache       getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
1014dc72d47SNicolas Vasilache   if (!extractSliceOp)
1024dc72d47SNicolas Vasilache     return rewriter.notifyMatchFailure(readOp, "not an extract_slice");
1034dc72d47SNicolas Vasilache 
1044dc72d47SNicolas Vasilache   LogicalResult preconditionResult =
1054dc72d47SNicolas Vasilache       preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
1064dc72d47SNicolas Vasilache                                                      extractSliceOp);
1074dc72d47SNicolas Vasilache   if (failed(preconditionResult))
1081ede5039SHugo Trachino     return rewriter.notifyMatchFailure(readOp, "Failed preconditions");
1094dc72d47SNicolas Vasilache 
1104dc72d47SNicolas Vasilache   SmallVector<Value> indices(readOp.getIndices().begin(),
1114dc72d47SNicolas Vasilache                              readOp.getIndices().end());
1124dc72d47SNicolas Vasilache   SmallVector<Value> sourceIndices;
1134c48f016SMatthias Springer   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
1144dc72d47SNicolas Vasilache       rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
1154dc72d47SNicolas Vasilache       extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
1164dc72d47SNicolas Vasilache       indices, sourceIndices);
1174dc72d47SNicolas Vasilache 
1181ede5039SHugo Trachino   Operation *newOp = rewriter.create<vector::TransferReadOp>(
1191ede5039SHugo Trachino       readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
1201ede5039SHugo Trachino       sourceIndices,
1214dc72d47SNicolas Vasilache       AffineMapAttr::get(expandDimsToRank(
1224dc72d47SNicolas Vasilache           readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
1234dc72d47SNicolas Vasilache           extractSliceOp.getDroppedDims())),
1244dc72d47SNicolas Vasilache       readOp.getPadding(),
1254dc72d47SNicolas Vasilache       /*mask=*/Value(), readOp.getInBoundsAttr());
1261ede5039SHugo Trachino   if (maskOp)
1271ede5039SHugo Trachino     newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
1281ede5039SHugo Trachino   return newOp->getResults()[0];
1294dc72d47SNicolas Vasilache }
1304dc72d47SNicolas Vasilache 
1314dc72d47SNicolas Vasilache LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
1324dc72d47SNicolas Vasilache     tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
1334dc72d47SNicolas Vasilache   auto writeOp = getTensorOperand(insertSliceOp)
1344dc72d47SNicolas Vasilache                      .template getDefiningOp<vector::TransferWriteOp>();
1354dc72d47SNicolas Vasilache   if (!writeOp)
1364dc72d47SNicolas Vasilache     return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");
1374dc72d47SNicolas Vasilache 
1384dc72d47SNicolas Vasilache   LogicalResult preconditionResult =
1394dc72d47SNicolas Vasilache       preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
1404dc72d47SNicolas Vasilache                                                      insertSliceOp);
1414dc72d47SNicolas Vasilache   if (failed(preconditionResult))
1424dc72d47SNicolas Vasilache     return preconditionResult;
1434dc72d47SNicolas Vasilache 
144760ffa47SRajveer Singh Bharadwaj   if (!doesTransferWriteCoverInsertSlice(writeOp))
145760ffa47SRajveer Singh Bharadwaj     return rewriter.notifyMatchFailure(
146760ffa47SRajveer Singh Bharadwaj         insertSliceOp, "transfer_write does not cover insert_slice");
147760ffa47SRajveer Singh Bharadwaj 
1484dc72d47SNicolas Vasilache   SmallVector<Value> indices(writeOp.getIndices().begin(),
1494dc72d47SNicolas Vasilache                              writeOp.getIndices().end());
1504dc72d47SNicolas Vasilache   SmallVector<Value> sourceIndices;
1514c48f016SMatthias Springer   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
1524dc72d47SNicolas Vasilache       rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
1534dc72d47SNicolas Vasilache       insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
1544dc72d47SNicolas Vasilache       sourceIndices);
1554dc72d47SNicolas Vasilache 
1564dc72d47SNicolas Vasilache   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1574dc72d47SNicolas Vasilache       insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
1584dc72d47SNicolas Vasilache       AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
1594dc72d47SNicolas Vasilache                                           insertSliceOp.getDestType().getRank(),
1604dc72d47SNicolas Vasilache                                           insertSliceOp.getDroppedDims())),
1614dc72d47SNicolas Vasilache       writeOp.getInBoundsAttr());
1624dc72d47SNicolas Vasilache 
1634dc72d47SNicolas Vasilache   return success();
1644dc72d47SNicolas Vasilache }
1654dc72d47SNicolas Vasilache 
166760ffa47SRajveer Singh Bharadwaj bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
167760ffa47SRajveer Singh Bharadwaj     vector::TransferWriteOp writeOp) {
168760ffa47SRajveer Singh Bharadwaj   if (writeOp.getShapedType().hasStaticShape())
169760ffa47SRajveer Singh Bharadwaj     return llvm::equal(writeOp.getVectorType().getShape(),
170760ffa47SRajveer Singh Bharadwaj                        writeOp.getShapedType().getShape());
171760ffa47SRajveer Singh Bharadwaj 
172760ffa47SRajveer Singh Bharadwaj   // TODO: Use ValueBoundsConstraintSet for dynamic shapes.
173760ffa47SRajveer Singh Bharadwaj 
174760ffa47SRajveer Singh Bharadwaj   return false;
175760ffa47SRajveer Singh Bharadwaj }
176760ffa47SRajveer Singh Bharadwaj 
17733468a51SNicolas Vasilache template <typename OpTy>
17833468a51SNicolas Vasilache struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
17933468a51SNicolas Vasilache   using OpRewritePattern<OpTy>::OpRewritePattern;
18033468a51SNicolas Vasilache 
18133468a51SNicolas Vasilache   LogicalResult matchAndRewrite(OpTy insertSliceOp,
18233468a51SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
18333468a51SNicolas Vasilache     auto sourceInsertSliceOp =
18433468a51SNicolas Vasilache         insertSliceOp.getSource()
18533468a51SNicolas Vasilache             .template getDefiningOp<tensor::InsertSliceOp>();
18633468a51SNicolas Vasilache     if (!sourceInsertSliceOp)
18733468a51SNicolas Vasilache       return failure();
18833468a51SNicolas Vasilache 
18933468a51SNicolas Vasilache     // TODO: relax unit stride assumption where possible.
19033468a51SNicolas Vasilache     if (!insertSliceOp.hasUnitStride()) {
19133468a51SNicolas Vasilache       return rewriter.notifyMatchFailure(insertSliceOp,
19233468a51SNicolas Vasilache                                          "requires unit strides");
19333468a51SNicolas Vasilache     }
19433468a51SNicolas Vasilache     if (!sourceInsertSliceOp.hasUnitStride()) {
19533468a51SNicolas Vasilache       return rewriter.notifyMatchFailure(sourceInsertSliceOp,
19633468a51SNicolas Vasilache                                          "requires unit strides");
19733468a51SNicolas Vasilache     }
19833468a51SNicolas Vasilache 
19933468a51SNicolas Vasilache     int64_t srcDim = 0;
20033468a51SNicolas Vasilache     llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
20133468a51SNicolas Vasilache     for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
20233468a51SNicolas Vasilache       if (droppedDims[d])
20333468a51SNicolas Vasilache         continue;
20433468a51SNicolas Vasilache       if (insertSliceOp.getMixedSizes()[d] !=
20533468a51SNicolas Vasilache           sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
20633468a51SNicolas Vasilache         return rewriter.notifyMatchFailure(
20733468a51SNicolas Vasilache             sourceInsertSliceOp,
20833468a51SNicolas Vasilache             "requires matching sizes to fold, otherwise a copy is needed");
20933468a51SNicolas Vasilache       }
21033468a51SNicolas Vasilache     }
21133468a51SNicolas Vasilache 
21233468a51SNicolas Vasilache     // Resolve sizes according to dropped dims.
21333468a51SNicolas Vasilache     SmallVector<OpFoldResult> resolvedSizes;
21433468a51SNicolas Vasilache     // Note: the "insertSlice" case is symmetrical to the extract/subview case:
21533468a51SNicolas Vasilache     // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
21633468a51SNicolas Vasilache     // passed as the destination to the helper function.
2174c48f016SMatthias Springer     affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
21833468a51SNicolas Vasilache                                         sourceInsertSliceOp.getMixedSizes(),
21933468a51SNicolas Vasilache                                         droppedDims, resolvedSizes);
22033468a51SNicolas Vasilache 
22133468a51SNicolas Vasilache     // If we are inside an InParallel region, temporarily set the insertion
22233468a51SNicolas Vasilache     // point outside: only tensor.parallel_insert_slice ops are allowed in
22333468a51SNicolas Vasilache     // there.
22433468a51SNicolas Vasilache     if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
22533468a51SNicolas Vasilache       rewriter.setInsertionPoint(
22633468a51SNicolas Vasilache           insertSliceOp->template getParentOfType<scf::InParallelOp>());
22733468a51SNicolas Vasilache     }
22833468a51SNicolas Vasilache 
22933468a51SNicolas Vasilache     // Resolve offsets according to source offsets and strides.
23033468a51SNicolas Vasilache     SmallVector<Value> resolvedOffsets;
23133468a51SNicolas Vasilache     // Note: the "insertSlice" case is symmetrical to the extract/subview case:
23233468a51SNicolas Vasilache     // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
23333468a51SNicolas Vasilache     // passed as the destination to the helper function.
2344c48f016SMatthias Springer     affine::resolveIndicesIntoOpWithOffsetsAndStrides(
23533468a51SNicolas Vasilache         rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
23633468a51SNicolas Vasilache         insertSliceOp.getMixedStrides(), droppedDims,
23733468a51SNicolas Vasilache         sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);
23833468a51SNicolas Vasilache 
23933468a51SNicolas Vasilache     // Reset the insertion point.
24033468a51SNicolas Vasilache     rewriter.setInsertionPoint(insertSliceOp);
24133468a51SNicolas Vasilache     // Replace original op.
24233468a51SNicolas Vasilache     rewriter.replaceOpWithNewOp<OpTy>(
24333468a51SNicolas Vasilache         insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
24433468a51SNicolas Vasilache         getAsOpFoldResult(resolvedOffsets), resolvedSizes,
24533468a51SNicolas Vasilache         insertSliceOp.getMixedStrides());
24633468a51SNicolas Vasilache 
24733468a51SNicolas Vasilache     return success();
24833468a51SNicolas Vasilache   }
24933468a51SNicolas Vasilache };
25033468a51SNicolas Vasilache 
2514dc72d47SNicolas Vasilache void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
252867afe5eSMatthias Springer   populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
253867afe5eSMatthias Springer   patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
25433468a51SNicolas Vasilache                InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
25533468a51SNicolas Vasilache       patterns.getContext());
2564dc72d47SNicolas Vasilache }
257867afe5eSMatthias Springer 
258867afe5eSMatthias Springer void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
259867afe5eSMatthias Springer     RewritePatternSet &patterns) {
260867afe5eSMatthias Springer   patterns.add<TransferReadOfExtractSliceOpFolder,
261867afe5eSMatthias Springer                InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
262867afe5eSMatthias Springer }
263867afe5eSMatthias Springer 
2644dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===//
2654dc72d47SNicolas Vasilache // Pass registration
2664dc72d47SNicolas Vasilache //===----------------------------------------------------------------------===//
2674dc72d47SNicolas Vasilache 
2684dc72d47SNicolas Vasilache namespace {
2694dc72d47SNicolas Vasilache 
2704dc72d47SNicolas Vasilache struct FoldTensorSubsetOpsPass final
2714dc72d47SNicolas Vasilache     : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
2724dc72d47SNicolas Vasilache   void runOnOperation() override;
2734dc72d47SNicolas Vasilache };
2744dc72d47SNicolas Vasilache 
2754dc72d47SNicolas Vasilache } // namespace
2764dc72d47SNicolas Vasilache 
2774dc72d47SNicolas Vasilache void FoldTensorSubsetOpsPass::runOnOperation() {
2784dc72d47SNicolas Vasilache   RewritePatternSet patterns(&getContext());
2794dc72d47SNicolas Vasilache   tensor::populateFoldTensorSubsetOpPatterns(patterns);
280*09dfc571SJacques Pienaar   (void)applyPatternsGreedily(getOperation(), std::move(patterns));
2814dc72d47SNicolas Vasilache }
2824dc72d47SNicolas Vasilache 
2834dc72d47SNicolas Vasilache std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
2844dc72d47SNicolas Vasilache   return std::make_unique<FoldTensorSubsetOpsPass>();
2854dc72d47SNicolas Vasilache }
286