1 //===- MergeConsecutiveInsertExtractSlicePatterns.cpp ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Affine/IR/AffineOps.h" 10 #include "mlir/Dialect/Tensor/IR/Tensor.h" 11 #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" 12 #include "mlir/Dialect/Tensor/Transforms/Transforms.h" 13 #include "mlir/IR/BuiltinTypes.h" 14 #include "mlir/IR/OpDefinition.h" 15 #include "mlir/IR/PatternMatch.h" 16 17 using namespace mlir; 18 using namespace mlir::tensor; 19 20 /// Creates AffineExpr from `ofr`: if the OpFoldResult is a Value, creates a 21 /// AffineSymbolExpr and appends it to `symbols`; otherwise creates a 22 /// AffineConstantExpr. 23 static AffineExpr getAffineExpr(OpFoldResult ofr, 24 SmallVector<OpFoldResult> &symbols) { 25 if (auto attr = ofr.dyn_cast<Attribute>()) { 26 return getAffineConstantExpr(attr.cast<IntegerAttr>().getInt(), 27 attr.getContext()); 28 } 29 Value v = ofr.get<Value>(); 30 AffineExpr expr = getAffineSymbolExpr(symbols.size(), v.getContext()); 31 symbols.push_back(v); 32 return expr; 33 } 34 35 /// Builds the AffineExpr incrementally for arithmetic operations. 36 static AffineExpr add(AffineExpr expr, OpFoldResult ofr, 37 SmallVector<OpFoldResult> &symbols) { 38 return expr + getAffineExpr(ofr, symbols); 39 } 40 static AffineExpr mul(OpFoldResult lhs, OpFoldResult rhs, 41 SmallVector<OpFoldResult> &symbols) { 42 return getAffineExpr(lhs, symbols) * getAffineExpr(rhs, symbols); 43 } 44 45 /// Converts an AffineExpr to OpFoldResult by generating an `affine.apply` 46 /// op and fold it. 47 static OpFoldResult getOpFoldResult(OpBuilder &builder, Location loc, 48 AffineExpr expr, 49 SmallVector<OpFoldResult> &symbols) { 50 AffineMap m = AffineMap::get(0, symbols.size(), expr); 51 return makeComposedFoldedAffineApply(builder, loc, m, symbols); 52 } 53 54 LogicalResult tensor::mergeOffsetsSizesAndStrides( 55 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets, 56 ArrayRef<OpFoldResult> producerSizes, 57 ArrayRef<OpFoldResult> producerStrides, 58 const llvm::SmallBitVector &droppedProducerDims, 59 ArrayRef<OpFoldResult> consumerOffsets, 60 ArrayRef<OpFoldResult> consumerSizes, 61 ArrayRef<OpFoldResult> consumerStrides, 62 SmallVector<OpFoldResult> &combinedOffsets, 63 SmallVector<OpFoldResult> &combinedSizes, 64 SmallVector<OpFoldResult> &combinedStrides) { 65 combinedOffsets.resize(producerOffsets.size()); 66 combinedSizes.resize(producerOffsets.size()); 67 combinedStrides.resize(producerOffsets.size()); 68 unsigned consumerPos = 0; 69 for (auto i : llvm::seq<unsigned>(0, producerOffsets.size())) { 70 if (droppedProducerDims.test(i)) { 71 // For dropped dims, get the values from the producer. 72 combinedOffsets[i] = producerOffsets[i]; 73 combinedSizes[i] = producerSizes[i]; 74 combinedStrides[i] = producerStrides[i]; 75 continue; 76 } 77 SmallVector<OpFoldResult> offsetSymbols, strideSymbols; 78 // The combined offset is computed as 79 // producer_offset + consumer_offset * producer_strides. 80 combinedOffsets[i] = 81 getOpFoldResult(builder, loc, 82 add(mul(consumerOffsets[consumerPos], 83 producerStrides[i], offsetSymbols), 84 producerOffsets[i], offsetSymbols), 85 offsetSymbols); 86 combinedSizes[i] = consumerSizes[consumerPos]; 87 // The combined stride is computed as 88 // consumer_stride * producer_stride. 89 combinedStrides[i] = getOpFoldResult( 90 builder, loc, 91 mul(consumerStrides[consumerPos], producerStrides[i], strideSymbols), 92 strideSymbols); 93 consumerPos++; 94 } 95 return success(); 96 } 97 98 LogicalResult tensor::mergeOffsetsSizesAndStrides( 99 OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer, 100 OffsetSizeAndStrideOpInterface consumer, 101 const llvm::SmallBitVector &droppedProducerDims, 102 SmallVector<OpFoldResult> &combinedOffsets, 103 SmallVector<OpFoldResult> &combinedSizes, 104 SmallVector<OpFoldResult> &combinedStrides) { 105 SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets(); 106 SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes(); 107 SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides(); 108 SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets(); 109 SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes(); 110 SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides(); 111 return tensor::mergeOffsetsSizesAndStrides( 112 builder, loc, producerOffsets, producerSizes, producerStrides, 113 droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides, 114 combinedOffsets, combinedSizes, combinedStrides); 115 } 116 117 namespace { 118 /// Merges consecutive tensor.extract_slice ops into one. 119 struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> { 120 using OpRewritePattern::OpRewritePattern; 121 122 LogicalResult matchAndRewrite(ExtractSliceOp nextOp, 123 PatternRewriter &rewriter) const override { 124 auto prevOp = nextOp.getSource().getDefiningOp<ExtractSliceOp>(); 125 if (!prevOp) 126 return failure(); 127 128 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides; 129 if (failed(mergeOffsetsSizesAndStrides(rewriter, nextOp.getLoc(), prevOp, 130 nextOp, prevOp.getDroppedDims(), 131 newOffsets, newSizes, newStrides))) 132 return failure(); 133 134 rewriter.replaceOpWithNewOp<ExtractSliceOp>(nextOp, nextOp.getType(), 135 prevOp.getSource(), newOffsets, 136 newSizes, newStrides); 137 return success(); 138 } 139 }; 140 141 /// Merges consecutive tensor.insert_slice ops into one. 142 struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> { 143 using OpRewritePattern::OpRewritePattern; 144 145 LogicalResult matchAndRewrite(InsertSliceOp nextOp, 146 PatternRewriter &rewriter) const override { 147 auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>(); 148 if (!prevOp) 149 return failure(); 150 151 if (!prevOp.hasUnitStride() || !nextOp.hasUnitStride()) 152 return failure(); 153 154 // The first insert_slice op should be rank reducing to make sure we cover 155 // the full source tensor to be inserted in the second insert_slice op. 156 SliceVerificationResult result = 157 isRankReducedType(prevOp.getDestType(), prevOp.getSourceType()); 158 if (result != SliceVerificationResult::Success) 159 return failure(); 160 161 // Dynamic dimensions can pass rank reducing check in the above, e.g, 162 // inserting <?xf32> into <1x?x1xf32>. For such cases we cannot be certain 163 // the dynamic size covers the full tensor. 164 if (!prevOp.getSourceType().hasStaticShape() || 165 !prevOp.getDestType().hasStaticShape()) 166 return failure(); 167 168 rewriter.replaceOpWithNewOp<InsertSliceOp>( 169 nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(), 170 nextOp.getMixedSizes(), nextOp.getMixedStrides()); 171 return success(); 172 } 173 }; 174 } // namespace 175 176 void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns( 177 RewritePatternSet &patterns) { 178 patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>( 179 patterns.getContext()); 180 } 181