//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold tensor subset ops with producer / consumers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>

namespace mlir {
namespace tensor {
// Pulls in the tablegen-generated pass base class
// (tensor::impl::FoldTensorSubsetOpsBase) used by the pass below.
#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // namespace mlir

using namespace mlir;

/// Returns the tensor operand that `op` reads from (transfer_read overload),
/// i.e. the candidate producer to fold into.
static Value getTensorOperand(vector::TransferReadOp op) {
  return op.getSource();
}

/// Returns the tensor being inserted by `op` (insert_slice overload),
/// i.e. the candidate producer to fold into.
static Value getTensorOperand(tensor::InsertSliceOp op) {
  return op.getSource();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merge extract_slice operation with load/transferRead operation.
50 class TransferReadOfExtractSliceOpFolder final 51 : public OpRewritePattern<vector::TransferReadOp> { 52 public: 53 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 54 55 LogicalResult matchAndRewrite(vector::TransferReadOp readOp, 56 PatternRewriter &rewriter) const override; 57 }; 58 59 /// Merge insert_slice operation with store/transferWriteOp operation. 60 class InsertSliceOfTransferWriteOpFolder final 61 : public OpRewritePattern<tensor::InsertSliceOp> { 62 public: 63 using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern; 64 65 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp, 66 PatternRewriter &rewriter) const override; 67 }; 68 } // namespace 69 70 template <typename XferOp, typename ExtractOrInsertOp> 71 static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp( 72 RewriterBase &rewriter, XferOp xferOp, 73 ExtractOrInsertOp extractOrInsertSliceOp) { 74 if (xferOp.hasOutOfBoundsDim()) 75 return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim"); 76 if (xferOp.getMask()) 77 return rewriter.notifyMatchFailure(xferOp, "masked transfer"); 78 if (!extractOrInsertSliceOp.hasUnitStride()) { 79 return rewriter.notifyMatchFailure( 80 xferOp, "non-1 stride insert/extract, requires keeping track of " 81 "strides, this may result in needing to insert " 82 "vector.insert_strided_slice/extract_strided_slice ops"); 83 } 84 return success(); 85 } 86 87 LogicalResult TransferReadOfExtractSliceOpFolder::matchAndRewrite( 88 vector::TransferReadOp readOp, PatternRewriter &rewriter) const { 89 auto extractSliceOp = 90 getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>(); 91 if (!extractSliceOp) 92 return rewriter.notifyMatchFailure(readOp, "not an extract_slice"); 93 94 LogicalResult preconditionResult = 95 preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp, 96 extractSliceOp); 97 if (failed(preconditionResult)) 98 return preconditionResult; 99 100 SmallVector<Value> 
indices(readOp.getIndices().begin(), 101 readOp.getIndices().end()); 102 SmallVector<Value> sourceIndices; 103 affine::resolveIndicesIntoOpWithOffsetsAndStrides( 104 rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(), 105 extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(), 106 indices, sourceIndices); 107 108 rewriter.replaceOpWithNewOp<vector::TransferReadOp>( 109 readOp, readOp.getVectorType(), extractSliceOp.getSource(), sourceIndices, 110 AffineMapAttr::get(expandDimsToRank( 111 readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(), 112 extractSliceOp.getDroppedDims())), 113 readOp.getPadding(), 114 /*mask=*/Value(), readOp.getInBoundsAttr()); 115 116 return success(); 117 } 118 119 LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite( 120 tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const { 121 auto writeOp = getTensorOperand(insertSliceOp) 122 .template getDefiningOp<vector::TransferWriteOp>(); 123 if (!writeOp) 124 return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write"); 125 126 LogicalResult preconditionResult = 127 preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp, 128 insertSliceOp); 129 if (failed(preconditionResult)) 130 return preconditionResult; 131 132 SmallVector<Value> indices(writeOp.getIndices().begin(), 133 writeOp.getIndices().end()); 134 SmallVector<Value> sourceIndices; 135 affine::resolveIndicesIntoOpWithOffsetsAndStrides( 136 rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(), 137 insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices, 138 sourceIndices); 139 140 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 141 insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices, 142 AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(), 143 insertSliceOp.getDestType().getRank(), 144 insertSliceOp.getDroppedDims())), 145 writeOp.getInBoundsAttr()); 146 147 return success(); 148 } 149 150 
template <typename OpTy> 151 struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> { 152 using OpRewritePattern<OpTy>::OpRewritePattern; 153 154 LogicalResult matchAndRewrite(OpTy insertSliceOp, 155 PatternRewriter &rewriter) const override { 156 auto sourceInsertSliceOp = 157 insertSliceOp.getSource() 158 .template getDefiningOp<tensor::InsertSliceOp>(); 159 if (!sourceInsertSliceOp) 160 return failure(); 161 162 // TODO: relax unit stride assumption where possible. 163 if (!insertSliceOp.hasUnitStride()) { 164 return rewriter.notifyMatchFailure(insertSliceOp, 165 "requires unit strides"); 166 } 167 if (!sourceInsertSliceOp.hasUnitStride()) { 168 return rewriter.notifyMatchFailure(sourceInsertSliceOp, 169 "requires unit strides"); 170 } 171 172 int64_t srcDim = 0; 173 llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims(); 174 for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) { 175 if (droppedDims[d]) 176 continue; 177 if (insertSliceOp.getMixedSizes()[d] != 178 sourceInsertSliceOp.getMixedSizes()[srcDim++]) { 179 return rewriter.notifyMatchFailure( 180 sourceInsertSliceOp, 181 "requires matching sizes to fold, otherwise a copy is needed"); 182 } 183 } 184 185 // Resolve sizes according to dropped dims. 186 SmallVector<OpFoldResult> resolvedSizes; 187 // Note: the "insertSlice" case is symmetrical to the extract/subview case: 188 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is 189 // passed as the destination to the helper function. 190 affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(), 191 sourceInsertSliceOp.getMixedSizes(), 192 droppedDims, resolvedSizes); 193 194 // If we are inside an InParallel region, temporarily set the insertion 195 // point outside: only tensor.parallel_insert_slice ops are allowed in 196 // there. 
197 if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) { 198 rewriter.setInsertionPoint( 199 insertSliceOp->template getParentOfType<scf::InParallelOp>()); 200 } 201 202 // Resolve offsets according to source offsets and strides. 203 SmallVector<Value> resolvedOffsets; 204 // Note: the "insertSlice" case is symmetrical to the extract/subview case: 205 // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is 206 // passed as the destination to the helper function. 207 affine::resolveIndicesIntoOpWithOffsetsAndStrides( 208 rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(), 209 insertSliceOp.getMixedStrides(), droppedDims, 210 sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets); 211 212 // Reset the insertion point. 213 rewriter.setInsertionPoint(insertSliceOp); 214 // Replace original op. 215 rewriter.replaceOpWithNewOp<OpTy>( 216 insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(), 217 getAsOpFoldResult(resolvedOffsets), resolvedSizes, 218 insertSliceOp.getMixedStrides()); 219 220 return success(); 221 } 222 }; 223 224 void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) { 225 patterns.add<TransferReadOfExtractSliceOpFolder, 226 InsertSliceOfTransferWriteOpFolder, 227 InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>, 228 InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>( 229 patterns.getContext()); 230 } 231 //===----------------------------------------------------------------------===// 232 // Pass registration 233 //===----------------------------------------------------------------------===// 234 235 namespace { 236 237 struct FoldTensorSubsetOpsPass final 238 : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> { 239 void runOnOperation() override; 240 }; 241 242 } // namespace 243 244 void FoldTensorSubsetOpsPass::runOnOperation() { 245 RewritePatternSet patterns(&getContext()); 246 tensor::populateFoldTensorSubsetOpPatterns(patterns); 247 
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 248 } 249 250 std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() { 251 return std::make_unique<FoldTensorSubsetOpsPass>(); 252 } 253