1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/Dialect/Utils/IndexingUtils.h" 13 #include "mlir/Dialect/Vector/IR/VectorOps.h" 14 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 15 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 16 #include "mlir/IR/BuiltinTypes.h" 17 18 using namespace mlir; 19 using namespace mlir::vector; 20 21 // Helper that picks the proper sequence for inserting. 22 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 23 Value into, int64_t offset) { 24 auto vectorType = into.getType().cast<VectorType>(); 25 if (vectorType.getRank() > 1) 26 return rewriter.create<InsertOp>(loc, from, into, offset); 27 return rewriter.create<vector::InsertElementOp>( 28 loc, vectorType, from, into, 29 rewriter.create<arith::ConstantIndexOp>(loc, offset)); 30 } 31 32 // Helper that picks the proper sequence for extracting. 33 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 34 int64_t offset) { 35 auto vectorType = vector.getType().cast<VectorType>(); 36 if (vectorType.getRank() > 1) 37 return rewriter.create<ExtractOp>(loc, vector, offset); 38 return rewriter.create<vector::ExtractElementOp>( 39 loc, vectorType.getElementType(), vector, 40 rewriter.create<arith::ConstantIndexOp>(loc, offset)); 41 } 42 43 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 44 /// have different ranks. 45 /// 46 /// When ranks are different, InsertStridedSlice needs to extract a properly 47 /// ranked vector from the destination vector into which to insert. This pattern 48 /// only takes care of this extraction part and forwards the rest to 49 /// [VectorInsertStridedSliceOpSameRankRewritePattern]. 50 /// 51 /// For a k-D source and n-D destination vector (k < n), we emit: 52 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to 53 /// insert the k-D source. 54 /// 2. k-D -> (n-1)-D InsertStridedSlice op 55 /// 3. InsertOp that is the reverse of 1. 56 class VectorInsertStridedSliceOpDifferentRankRewritePattern 57 : public OpRewritePattern<InsertStridedSliceOp> { 58 public: 59 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 60 61 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 62 PatternRewriter &rewriter) const override { 63 auto srcType = op.getSourceVectorType(); 64 auto dstType = op.getDestVectorType(); 65 66 if (op.offsets().getValue().empty()) 67 return failure(); 68 69 auto loc = op.getLoc(); 70 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 71 assert(rankDiff >= 0); 72 if (rankDiff == 0) 73 return failure(); 74 75 int64_t rankRest = dstType.getRank() - rankDiff; 76 // Extract / insert the subvector of matching rank and InsertStridedSlice 77 // on it. 78 Value extracted = 79 rewriter.create<ExtractOp>(loc, op.dest(), 80 getI64SubArray(op.offsets(), /*dropFront=*/0, 81 /*dropBack=*/rankRest)); 82 83 // A different pattern will kick in for InsertStridedSlice with matching 84 // ranks. 85 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 86 loc, op.source(), extracted, 87 getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), 88 getI64SubArray(op.strides(), /*dropFront=*/0)); 89 90 rewriter.replaceOpWithNewOp<InsertOp>( 91 op, stridedSliceInnerOp.getResult(), op.dest(), 92 getI64SubArray(op.offsets(), /*dropFront=*/0, 93 /*dropBack=*/rankRest)); 94 return success(); 95 } 96 }; 97 98 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 99 /// have the same rank. For each outermost index in the slice: 100 /// begin end stride 101 /// [offset : offset+size*stride : stride] 102 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector. 103 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D 104 /// 3. the destination subvector is inserted back in the proper place 105 /// 3. InsertOp that is the reverse of 1. 106 class VectorInsertStridedSliceOpSameRankRewritePattern 107 : public OpRewritePattern<InsertStridedSliceOp> { 108 public: 109 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 110 111 void initialize() { 112 // This pattern creates recursive InsertStridedSliceOp, but the recursion is 113 // bounded as the rank is strictly decreasing. 114 setHasBoundedRewriteRecursion(); 115 } 116 117 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 118 PatternRewriter &rewriter) const override { 119 auto srcType = op.getSourceVectorType(); 120 auto dstType = op.getDestVectorType(); 121 122 if (op.offsets().getValue().empty()) 123 return failure(); 124 125 int64_t srcRank = srcType.getRank(); 126 int64_t dstRank = dstType.getRank(); 127 assert(dstRank >= srcRank); 128 if (dstRank != srcRank) 129 return failure(); 130 131 if (srcType == dstType) { 132 rewriter.replaceOp(op, op.source()); 133 return success(); 134 } 135 136 int64_t offset = 137 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 138 int64_t size = srcType.getShape().front(); 139 int64_t stride = 140 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 141 142 auto loc = op.getLoc(); 143 Value res = op.dest(); 144 145 if (srcRank == 1) { 146 int nSrc = srcType.getShape().front(); 147 int nDest = dstType.getShape().front(); 148 // 1. Scale source to destType so we can shufflevector them together. 149 SmallVector<int64_t> offsets(nDest, 0); 150 for (int64_t i = 0; i < nSrc; ++i) 151 offsets[i] = i; 152 Value scaledSource = 153 rewriter.create<ShuffleOp>(loc, op.source(), op.source(), offsets); 154 155 // 2. Create a mask where we take the value from scaledSource of dest 156 // depending on the offset. 157 offsets.clear(); 158 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) { 159 if (i < offset || i >= e || (i - offset) % stride != 0) 160 offsets.push_back(nDest + i); 161 else 162 offsets.push_back((i - offset) / stride); 163 } 164 165 // 3. Replace with a ShuffleOp. 166 rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.dest(), 167 offsets); 168 169 return success(); 170 } 171 172 // For each slice of the source vector along the most major dimension. 173 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 174 off += stride, ++idx) { 175 // 1. extract the proper subvector (or element) from source 176 Value extractedSource = extractOne(rewriter, loc, op.source(), idx); 177 if (extractedSource.getType().isa<VectorType>()) { 178 // 2. If we have a vector, extract the proper subvector from destination 179 // Otherwise we are at the element level and no need to recurse. 180 Value extractedDest = extractOne(rewriter, loc, op.dest(), off); 181 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 182 // smaller rank. 183 extractedSource = rewriter.create<InsertStridedSliceOp>( 184 loc, extractedSource, extractedDest, 185 getI64SubArray(op.offsets(), /* dropFront=*/1), 186 getI64SubArray(op.strides(), /* dropFront=*/1)); 187 } 188 // 4. Insert the extractedSource into the res vector. 189 res = insertOne(rewriter, loc, extractedSource, res, off); 190 } 191 192 rewriter.replaceOp(op, res); 193 return success(); 194 } 195 }; 196 197 /// Progressive lowering of ExtractStridedSliceOp to either: 198 /// 1. single offset extract as a direct vector::ShuffleOp. 199 /// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp + 200 /// InsertOp/InsertElementOp for the n-D case. 201 class VectorExtractStridedSliceOpRewritePattern 202 : public OpRewritePattern<ExtractStridedSliceOp> { 203 public: 204 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 205 206 void initialize() { 207 // This pattern creates recursive ExtractStridedSliceOp, but the recursion 208 // is bounded as the rank is strictly decreasing. 209 setHasBoundedRewriteRecursion(); 210 } 211 212 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 213 PatternRewriter &rewriter) const override { 214 auto dstType = op.getType(); 215 216 assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); 217 218 int64_t offset = 219 op.offsets().getValue().front().cast<IntegerAttr>().getInt(); 220 int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt(); 221 int64_t stride = 222 op.strides().getValue().front().cast<IntegerAttr>().getInt(); 223 224 auto loc = op.getLoc(); 225 auto elemType = dstType.getElementType(); 226 assert(elemType.isSignlessIntOrIndexOrFloat()); 227 228 // Single offset can be more efficiently shuffled. 229 if (op.offsets().getValue().size() == 1) { 230 SmallVector<int64_t, 4> offsets; 231 offsets.reserve(size); 232 for (int64_t off = offset, e = offset + size * stride; off < e; 233 off += stride) 234 offsets.push_back(off); 235 rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(), 236 op.vector(), 237 rewriter.getI64ArrayAttr(offsets)); 238 return success(); 239 } 240 241 // Extract/insert on a lower ranked extract strided slice op. 242 Value zero = rewriter.create<arith::ConstantOp>( 243 loc, elemType, rewriter.getZeroAttr(elemType)); 244 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 245 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 246 off += stride, ++idx) { 247 Value one = extractOne(rewriter, loc, op.vector(), off); 248 Value extracted = rewriter.create<ExtractStridedSliceOp>( 249 loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), 250 getI64SubArray(op.sizes(), /* dropFront=*/1), 251 getI64SubArray(op.strides(), /* dropFront=*/1)); 252 res = insertOne(rewriter, loc, extracted, res, idx); 253 } 254 rewriter.replaceOp(op, res); 255 return success(); 256 } 257 }; 258 259 /// Populate the given list with patterns that convert from Vector to LLVM. 260 void mlir::vector::populateVectorInsertExtractStridedSliceTransforms( 261 RewritePatternSet &patterns) { 262 patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern, 263 VectorInsertStridedSliceOpSameRankRewritePattern, 264 VectorExtractStridedSliceOpRewritePattern>( 265 patterns.getContext()); 266 } 267