//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) op pairs that cancel each other out.
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp is no longer rank-reducing are supported at the moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold a collapse_shape that only removes static dimensions of size `1`
/// into extract_slice.
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // The collapse cannot be folded away when the extract_slice has multiple
    // users, and it is not necessarily beneficial to only convert the
    // collapse into another extract_slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();

    // Only fold away simple collapses where all removed dimensions have
    // static size `1`.
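    // For example (illustrative shapes): collapsing tensor<1x4x1xf32> into
    // tensor<4xf32> removes only unit dims and can be folded, whereas
    // collapsing tensor<2x4xf32> into tensor<8xf32> cannot and is rejected
    // below.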
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
        extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) op pairs that cancel each other out.
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // InsertSliceOp is no longer rank-reducing are supported at the moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold an expand_shape that only adds static dimensions of size `1`
/// into insert_slice.
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only fold away simple expansions where all added dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");

    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations.
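///
/// For example (illustrative shapes, not taken from a test case):
///
///   %c = tensor.collapse_shape %in [[0, 1], [2]]
///       : tensor<2x4x8xf32> into tensor<8x8xf32>
///   %e = tensor.expand_shape %c [[0], [1, 2]] output_shape [8, 2, 4]
///       : tensor<8x8xf32> into tensor<8x2x4xf32>
///
/// is rewritten so that the expansion happens first:
///
///   %e = tensor.expand_shape %in [[0], [1], [2, 3]]
///       output_shape [2, 4, 2, 4]
///       : tensor<2x4x8xf32> into tensor<2x4x2x4xf32>
///   %c = tensor.collapse_shape %e [[0, 1], [2], [3]]
///       : tensor<2x4x2x4xf32> into tensor<8x2x4xf32>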
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // The reshapes are parallel to each other if no reassociation group has
    // more than one index in both reshapes, i.e. for each intermediate
    // dimension at most one of the two ops performs a non-trivial regrouping.
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> collapseSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> expandSizes(getMixedValues(
        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
    SmallVector<OpFoldResult> newExpandSizes;
    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(index);
          newExpandReInds.push_back({index++});
          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
        }
        newCollapseReInds.push_back(newCollapseReassociation);
        expandIndex++;
        continue;
      }
      ReassociationIndices newExpandReassociation;
      auto expandReassociation = expandReInds[idx];
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(index);
        newCollapseReInds.push_back({index++});
        newExpandSizes.push_back(expandSizes[expandIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      collapseIndex++;
    }

    // Swap reshape order.
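    // Create the new expand_shape on the collapse_shape's source using the
    // reassociations computed above, then replace the original expand_shape
    // with a collapse_shape of the newly expanded value.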
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
        newExpandSizes);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        expandOp, newExpand.getResult(), newCollapseReInds);
    return success();
  }
};

} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
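
// A minimal usage sketch (illustrative; not part of this file's API): the
// populate functions above are typically wired into a greedy rewrite inside
// a pass, e.g.:
//
//   RewritePatternSet patterns(op->getContext());
//   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
//   tensor::populateBubbleUpExpandShapePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));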