//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold tensor subset ops with their producers / consumers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>

namespace mlir {
namespace tensor {
#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // namespace mlir

using namespace mlir;

static Value getTensorOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getTensorOperand(tensor::InsertSliceOp op) {
  return op.getSource();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merge an extract_slice operation with a load/transfer_read operation.
class TransferReadOfExtractSliceOpFolder final
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
                            vector::MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};

/// Merge an insert_slice operation with a store/transfer_write operation.
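/// A sketch of the rewrite (illustrative IR only; the value names, shapes and
/// attributes below are made up for documentation, not taken from a test):
///
///   %w = vector.transfer_write %v, %src[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, tensor<4x8xf32>
///   %r = tensor.insert_slice %w into %dst[%i, %j] [4, 8] [1, 1]
///       : tensor<4x8xf32> into tensor<16x16xf32>
///
/// becomes, when the transfer_write covers the whole inserted slice:
///
///   %r = vector.transfer_write %v, %dst[%i, %j] {in_bounds = [true, true]}
///       : vector<4x8xf32>, tensor<16x16xf32>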
class InsertSliceOfTransferWriteOpFolder final
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;

private:
  static bool
  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
};
} // namespace

template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
  if (!extractOrInsertSliceOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride insert/extract, requires keeping track of "
                "strides, this may result in needing to insert "
                "vector.insert_strided_slice/extract_strided_slice ops");
  }
  return success();
}

FailureOr<mlir::Value>
TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
    vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  auto extractSliceOp =
      getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
  if (!extractSliceOp)
    return rewriter.notifyMatchFailure(readOp, "not an extract_slice");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
                                                     extractSliceOp);
  if (failed(preconditionResult))
    return rewriter.notifyMatchFailure(readOp, "failed preconditions");

  SmallVector<Value> indices(readOp.getIndices().begin(),
                             readOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
      extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
      indices, sourceIndices);

  Operation *newOp = rewriter.create<vector::TransferReadOp>(
      readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
      sourceIndices,
      AffineMapAttr::get(expandDimsToRank(
          readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
          extractSliceOp.getDroppedDims())),
      readOp.getPadding(),
      /*mask=*/Value(), readOp.getInBoundsAttr());
  if (maskOp)
    newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
  return newOp->getResults()[0];
}

LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
  auto writeOp = getTensorOperand(insertSliceOp)
                     .template getDefiningOp<vector::TransferWriteOp>();
  if (!writeOp)
    return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
                                                     insertSliceOp);
  if (failed(preconditionResult))
    return preconditionResult;

  if (!doesTransferWriteCoverInsertSlice(writeOp))
    return rewriter.notifyMatchFailure(
        insertSliceOp, "transfer_write does not cover insert_slice");

  SmallVector<Value> indices(writeOp.getIndices().begin(),
                             writeOp.getIndices().end());
  SmallVector<Value> sourceIndices;
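  // Fold the insert_slice offsets into the write indices. With the unit
  // strides checked above, this amounts to, per non-dropped destination dim,
  //   sourceIndices[d] = insert_slice offset[d] + write index
  // materialized as affine.apply ops; dropped dims keep the offset alone.
  // (A summary of what the helper below computes, not additional logic.)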
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
      insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
      sourceIndices);

  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
      AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
                                          insertSliceOp.getDestType().getRank(),
                                          insertSliceOp.getDroppedDims())),
      writeOp.getInBoundsAttr());

  return success();
}

bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
    vector::TransferWriteOp writeOp) {
  if (writeOp.getShapedType().hasStaticShape())
    return llvm::equal(writeOp.getVectorType().getShape(),
                       writeOp.getShapedType().getShape());

  // TODO: Use ValueBoundsConstraintSet for dynamic shapes.

  return false;
}

template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto sourceInsertSliceOp =
        insertSliceOp.getSource()
            .template getDefiningOp<tensor::InsertSliceOp>();
    if (!sourceInsertSliceOp)
      return failure();

    // TODO: relax unit stride assumption where possible.
    if (!insertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "requires unit strides");
    }
    if (!sourceInsertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
                                         "requires unit strides");
    }

    int64_t srcDim = 0;
    llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
    for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
      if (droppedDims[d])
        continue;
      if (insertSliceOp.getMixedSizes()[d] !=
          sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
        return rewriter.notifyMatchFailure(
            sourceInsertSliceOp,
            "requires matching sizes to fold, otherwise a copy is needed");
      }
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
    affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
                                        sourceInsertSliceOp.getMixedSizes(),
                                        droppedDims, resolvedSizes);

    // If we are inside an InParallel region, temporarily set the insertion
    // point outside: only tensor.parallel_insert_slice ops are allowed in
    // there.
    if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
      rewriter.setInsertionPoint(
          insertSliceOp->template getParentOfType<scf::InParallelOp>());
    }

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
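    // With the unit strides enforced above, this effectively computes, for
    // every destination dim the outer insert_slice keeps,
    //   resolvedOffsets[d] = outer offset[d] + inner offset
    // (dropped dims keep the outer offset). E.g., inserting at [2, 3] into an
    // intermediate tensor that is itself inserted at [5, 7] resolves to a
    // single insert at [7, 10]. (Illustrative numbers, not from a test.)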
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedStrides(), droppedDims,
        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);

    // Reset the insertion point.
    rewriter.setInsertionPoint(insertSliceOp);
    // Replace original op.
    rewriter.replaceOpWithNewOp<OpTy>(
        insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        insertSliceOp.getMixedStrides());

    return success();
  }
};

void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
  populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
  patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
      patterns.getContext());
}

void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadOfExtractSliceOpFolder,
               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldTensorSubsetOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  tensor::populateFoldTensorSubsetOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
  return std::make_unique<FoldTensorSubsetOpsPass>();
}
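
// Usage sketch (illustrative, mirroring what FoldTensorSubsetOpsPass does
// above): a client pass or transform can pull in the same rewrites with
//
//   RewritePatternSet patterns(ctx);
//   tensor::populateFoldTensorSubsetOpPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));
//
// or call populateFoldTensorSubsetIntoVectorTransferPatterns(patterns) to get
// only the transfer_read/transfer_write folds.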