//===- RankReductionPatterns.cpp - Patterns related to rank reductions ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) ops that cancel each other out.
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold a collapse_shape that only removes static dimensions of size `1`
/// into extract_slice.
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // The collapse cannot be folded away when the extract_slice has multiple
    // users, and it is not necessarily beneficial to only convert the
    // collapse into another extract_slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();

    // Only fold away a simple collapse where all removed dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
        extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) ops that cancel each other out.
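///
/// Illustrative IR (hypothetical shapes, not taken from a test case):
///   %0 = tensor.collapse_shape %src [[0, 1], [2]]
///       : tensor<1x5x6xf32> into tensor<5x6xf32>
///   %1 = tensor.insert_slice %0 into %dest[0, 0, 0][1, 5, 6][1, 1, 1]
///       : tensor<5x6xf32> into tensor<4x5x6xf32>
/// folds to the equivalent non-rank-reducing insert:
///   %1 = tensor.insert_slice %src into %dest[0, 0, 0][1, 5, 6][1, 1, 1]
///       : tensor<1x5x6xf32> into tensor<4x5x6xf32>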
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // InsertSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold an expand_shape that only adds static dimensions of size `1`
/// into insert_slice.
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only fold away a simple expansion where all added dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");

    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations.
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // The two reshapes are parallel to each other if, for every pair of
    // corresponding reassociation groups, at most one of the two reshapes
    // has more than one index in that group.
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
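    // As a worked example (hypothetical shapes, for exposition only):
    //   collapse: tensor<2x3x4xf32> into tensor<6x4xf32>,
    //             reassoc [[0, 1], [2]]
    //   expand:   tensor<6x4xf32> into tensor<6x2x2xf32>,
    //             reassoc [[0], [1, 2]]
    // is rewritten below into the swapped order
    //   expand:   tensor<2x3x4xf32> into tensor<2x3x2x2xf32>,
    //             reassoc [[0], [1], [2, 3]]
    //   collapse: tensor<2x3x2x2xf32> into tensor<6x2x2xf32>,
    //             reassoc [[0, 1], [2], [3]]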
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> collapseSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> expandSizes(getMixedValues(
        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
    SmallVector<OpFoldResult> newExpandSizes;
    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(index);
          newExpandReInds.push_back({index++});
          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
        }
        newCollapseReInds.push_back(newCollapseReassociation);
        expandIndex++;
        continue;
      }
      ReassociationIndices newExpandReassociation;
      auto expandReassociation = expandReInds[idx];
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(index);
        newCollapseReInds.push_back({index++});
        newExpandSizes.push_back(expandSizes[expandIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      collapseIndex++;
    }

    // Swap the reshape order.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
        newExpandSizes);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        expandOp, newExpand.getResult(), newCollapseReInds);
    return success();
  }
};
} // namespace

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}
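
// Typical usage from a pass (illustrative sketch; the pass boilerplate and
// the greedy driver call site are assumptions, not part of this file):
//   RewritePatternSet patterns(&getContext());
//   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
//   tensor::populateBubbleUpExpandShapePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));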