//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) ops that cancel each other out.
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp has no rank-reduction anymore are supported at the
    // moment.
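    // An illustrative sketch (hypothetical shapes, not taken from a test):
    //   %0 = tensor.extract_slice %t[0, 0, 0] [1, 8, 16] [1, 1, 1]
    //       : tensor<2x8x16xf32> to tensor<8x16xf32>
    //   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [1, 8, 16]
    //       : tensor<8x16xf32> into tensor<1x8x16xf32>
    // The inferred non-rank-reducing extract type, tensor<1x8x16xf32>, equals
    // the expand_shape result type, so the pair folds to one extract_slice.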
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold a collapse_shape that only removes static dimensions of size `1`
/// into an extract_slice.
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // The collapse cannot be folded away when the extract_slice has more
    // than one user, and it is not necessarily beneficial to merely convert
    // the collapse into another extract_slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();

    // Only fold away simple collapses where all removed dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
        extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) ops that cancel each other out.
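/// An illustrative sketch (shapes and values are hypothetical, not taken
/// from this file or its tests):
///   %0 = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<1x8x16xf32> into tensor<8x16xf32>
///   %1 = tensor.insert_slice %0 into %dest[0, 0, 0] [1, 8, 16] [1, 1, 1]
///       : tensor<8x16xf32> into tensor<2x8x16xf32>
/// The non-rank-reducing source type implied by the static sizes,
/// tensor<1x8x16xf32>, equals the collapse_shape source type, so this folds
/// to a single insert_slice of %src.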
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // InsertSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold an expand_shape that only adds static dimensions of size `1`
/// into an insert_slice.
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only fold away simple expansions where all added dimensions have
    // static size `1`.
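    // For instance (hypothetical shapes, not from a test), an expansion
    //   %0 = tensor.expand_shape %src [[0, 1], [2]] output_shape [1, 8, 16]
    //       : tensor<8x16xf32> into tensor<1x8x16xf32>
    // qualifies: the result type is the source type with only unit
    // dimensions inserted, so the insert_slice can consume %src directly.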
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");

    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations.
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // The reshapes are parallel to each other if no reassociation group
    // contains more than one index in both reshapes.
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
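    // As a hypothetical example (not from a test), for
    //   %0 = tensor.collapse_shape %src [[0, 1], [2]]
    //       : tensor<2x4x16xf32> into tensor<8x16xf32>
    //   %1 = tensor.expand_shape %0 [[0], [1, 2]] output_shape [8, 4, 4]
    //       : tensor<8x16xf32> into tensor<8x4x4xf32>
    // the collapsed group {0, 1} and the expanded group {1, 2} affect
    // disjoint dimensions, so the loop below rebuilds the pair as an
    // expand_shape of %src to tensor<2x4x4x4xf32> (reassociation
    // [[0], [1], [2, 3]]) followed by a collapse_shape with reassociation
    // [[0, 1], [2], [3]].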
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> collapseSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> expandSizes(getMixedValues(
        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
    SmallVector<OpFoldResult> newExpandSizes;
    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(index);
          newExpandReInds.push_back({index++});
          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
        }
        newCollapseReInds.push_back(newCollapseReassociation);
        expandIndex++;
        continue;
      }
      ReassociationIndices newExpandReassociation;
      auto expandReassociation = expandReInds[idx];
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(index);
        newCollapseReInds.push_back({index++});
        newExpandSizes.push_back(expandSizes[expandIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      collapseIndex++;
    }

    // Swap reshape order.
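    // Split the mixed sizes into static and dynamic operands for the new
    // expand_shape, then re-create the collapse_shape on top of its result
    // (in the sketch above: expand %src to tensor<2x4x4x4xf32>, then
    // collapse to tensor<8x4x4xf32>).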
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
        newExpandSizes);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        expandOp, newExpand.getResult(), newCollapseReInds);
    return success();
  }
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
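// A minimal usage sketch (hypothetical pass body, not part of this file):
//   RewritePatternSet patterns(&getContext());
//   tensor::populateBubbleUpExpandShapePatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();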