1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/Utils/IndexingUtils.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 17 using namespace mlir; 18 using namespace mlir::vector; 19 20 // Helper that picks the proper sequence for inserting. 21 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 22 Value into, int64_t offset) { 23 auto vectorType = into.getType().cast<VectorType>(); 24 if (vectorType.getRank() > 1) 25 return rewriter.create<InsertOp>(loc, from, into, offset); 26 return rewriter.create<vector::InsertElementOp>( 27 loc, vectorType, from, into, 28 rewriter.create<arith::ConstantIndexOp>(loc, offset)); 29 } 30 31 // Helper that picks the proper sequence for extracting. 32 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 33 int64_t offset) { 34 auto vectorType = vector.getType().cast<VectorType>(); 35 if (vectorType.getRank() > 1) 36 return rewriter.create<ExtractOp>(loc, vector, offset); 37 return rewriter.create<vector::ExtractElementOp>( 38 loc, vectorType.getElementType(), vector, 39 rewriter.create<arith::ConstantIndexOp>(loc, offset)); 40 } 41 42 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 43 /// have different ranks. 44 /// 45 /// When ranks are different, InsertStridedSlice needs to extract a properly 46 /// ranked vector from the destination vector into which to insert. This pattern 47 /// only takes care of this extraction part and forwards the rest to 48 /// [ConvertSameRankInsertStridedSliceIntoShuffle]. 49 /// 50 /// For a k-D source and n-D destination vector (k < n), we emit: 51 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to 52 /// insert the k-D source. 53 /// 2. k-D -> (n-1)-D InsertStridedSlice op 54 /// 3. InsertOp that is the reverse of 1. 55 class DecomposeDifferentRankInsertStridedSlice 56 : public OpRewritePattern<InsertStridedSliceOp> { 57 public: 58 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 59 60 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 61 PatternRewriter &rewriter) const override { 62 auto srcType = op.getSourceVectorType(); 63 auto dstType = op.getDestVectorType(); 64 65 if (op.getOffsets().getValue().empty()) 66 return failure(); 67 68 auto loc = op.getLoc(); 69 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 70 assert(rankDiff >= 0); 71 if (rankDiff == 0) 72 return failure(); 73 74 int64_t rankRest = dstType.getRank() - rankDiff; 75 // Extract / insert the subvector of matching rank and InsertStridedSlice 76 // on it. 77 Value extracted = rewriter.create<ExtractOp>( 78 loc, op.getDest(), 79 getI64SubArray(op.getOffsets(), /*dropFront=*/0, 80 /*dropBack=*/rankRest)); 81 82 // A different pattern will kick in for InsertStridedSlice with matching 83 // ranks. 84 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 85 loc, op.getSource(), extracted, 86 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff), 87 getI64SubArray(op.getStrides(), /*dropFront=*/0)); 88 89 rewriter.replaceOpWithNewOp<InsertOp>( 90 op, stridedSliceInnerOp.getResult(), op.getDest(), 91 getI64SubArray(op.getOffsets(), /*dropFront=*/0, 92 /*dropBack=*/rankRest)); 93 return success(); 94 } 95 }; 96 97 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 98 /// have the same rank. For each outermost index in the slice: 99 /// begin end stride 100 /// [offset : offset+size*stride : stride] 101 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector. 102 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D 103 /// 3. the destination subvector is inserted back in the proper place 104 /// 3. InsertOp that is the reverse of 1. 105 class ConvertSameRankInsertStridedSliceIntoShuffle 106 : public OpRewritePattern<InsertStridedSliceOp> { 107 public: 108 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 109 110 void initialize() { 111 // This pattern creates recursive InsertStridedSliceOp, but the recursion is 112 // bounded as the rank is strictly decreasing. 113 setHasBoundedRewriteRecursion(); 114 } 115 116 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 117 PatternRewriter &rewriter) const override { 118 auto srcType = op.getSourceVectorType(); 119 auto dstType = op.getDestVectorType(); 120 121 if (op.getOffsets().getValue().empty()) 122 return failure(); 123 124 int64_t srcRank = srcType.getRank(); 125 int64_t dstRank = dstType.getRank(); 126 assert(dstRank >= srcRank); 127 if (dstRank != srcRank) 128 return failure(); 129 130 if (srcType == dstType) { 131 rewriter.replaceOp(op, op.getSource()); 132 return success(); 133 } 134 135 int64_t offset = 136 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt(); 137 int64_t size = srcType.getShape().front(); 138 int64_t stride = 139 op.getStrides().getValue().front().cast<IntegerAttr>().getInt(); 140 141 auto loc = op.getLoc(); 142 Value res = op.getDest(); 143 144 if (srcRank == 1) { 145 int nSrc = srcType.getShape().front(); 146 int nDest = dstType.getShape().front(); 147 // 1. Scale source to destType so we can shufflevector them together. 148 SmallVector<int64_t> offsets(nDest, 0); 149 for (int64_t i = 0; i < nSrc; ++i) 150 offsets[i] = i; 151 Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(), 152 op.getSource(), offsets); 153 154 // 2. Create a mask where we take the value from scaledSource of dest 155 // depending on the offset. 156 offsets.clear(); 157 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) { 158 if (i < offset || i >= e || (i - offset) % stride != 0) 159 offsets.push_back(nDest + i); 160 else 161 offsets.push_back((i - offset) / stride); 162 } 163 164 // 3. Replace with a ShuffleOp. 165 rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(), 166 offsets); 167 168 return success(); 169 } 170 171 // For each slice of the source vector along the most major dimension. 172 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 173 off += stride, ++idx) { 174 // 1. extract the proper subvector (or element) from source 175 Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx); 176 if (extractedSource.getType().isa<VectorType>()) { 177 // 2. If we have a vector, extract the proper subvector from destination 178 // Otherwise we are at the element level and no need to recurse. 179 Value extractedDest = extractOne(rewriter, loc, op.getDest(), off); 180 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 181 // smaller rank. 182 extractedSource = rewriter.create<InsertStridedSliceOp>( 183 loc, extractedSource, extractedDest, 184 getI64SubArray(op.getOffsets(), /* dropFront=*/1), 185 getI64SubArray(op.getStrides(), /* dropFront=*/1)); 186 } 187 // 4. Insert the extractedSource into the res vector. 188 res = insertOne(rewriter, loc, extractedSource, res, off); 189 } 190 191 rewriter.replaceOp(op, res); 192 return success(); 193 } 194 }; 195 196 /// RewritePattern for ExtractStridedSliceOp where source and destination 197 /// vectors are 1-D. For such cases, we can lower it to a ShuffleOp. 198 class Convert1DExtractStridedSliceIntoShuffle 199 : public OpRewritePattern<ExtractStridedSliceOp> { 200 public: 201 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 202 203 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 204 PatternRewriter &rewriter) const override { 205 auto dstType = op.getType(); 206 207 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); 208 209 int64_t offset = 210 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt(); 211 int64_t size = 212 op.getSizes().getValue().front().cast<IntegerAttr>().getInt(); 213 int64_t stride = 214 op.getStrides().getValue().front().cast<IntegerAttr>().getInt(); 215 216 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat()); 217 218 // Single offset can be more efficiently shuffled. 219 if (op.getOffsets().getValue().size() != 1) 220 return failure(); 221 222 SmallVector<int64_t, 4> offsets; 223 offsets.reserve(size); 224 for (int64_t off = offset, e = offset + size * stride; off < e; 225 off += stride) 226 offsets.push_back(off); 227 rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(), 228 op.getVector(), 229 rewriter.getI64ArrayAttr(offsets)); 230 return success(); 231 } 232 }; 233 234 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D. 235 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower 236 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case. 237 class DecomposeNDExtractStridedSlice 238 : public OpRewritePattern<ExtractStridedSliceOp> { 239 public: 240 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 241 242 void initialize() { 243 // This pattern creates recursive ExtractStridedSliceOp, but the recursion 244 // is bounded as the rank is strictly decreasing. 245 setHasBoundedRewriteRecursion(); 246 } 247 248 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 249 PatternRewriter &rewriter) const override { 250 auto dstType = op.getType(); 251 252 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); 253 254 int64_t offset = 255 op.getOffsets().getValue().front().cast<IntegerAttr>().getInt(); 256 int64_t size = 257 op.getSizes().getValue().front().cast<IntegerAttr>().getInt(); 258 int64_t stride = 259 op.getStrides().getValue().front().cast<IntegerAttr>().getInt(); 260 261 auto elemType = dstType.getElementType(); 262 assert(elemType.isSignlessIntOrIndexOrFloat()); 263 264 // Single offset can be more efficiently shuffled. It's handled in 265 // Convert1DExtractStridedSliceIntoShuffle. 266 if (op.getOffsets().getValue().size() == 1) 267 return failure(); 268 269 // Extract/insert on a lower ranked extract strided slice op. 270 Value zero = rewriter.create<arith::ConstantOp>( 271 loc, elemType, rewriter.getZeroAttr(elemType)); 272 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 273 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 274 off += stride, ++idx) { 275 Value one = extractOne(rewriter, loc, op.getVector(), off); 276 Value extracted = rewriter.create<ExtractStridedSliceOp>( 277 loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1), 278 getI64SubArray(op.getSizes(), /* dropFront=*/1), 279 getI64SubArray(op.getStrides(), /* dropFront=*/1)); 280 res = insertOne(rewriter, loc, extracted, res, idx); 281 } 282 rewriter.replaceOp(op, res); 283 return success(); 284 } 285 }; 286 287 void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( 288 RewritePatternSet &patterns) { 289 patterns.add<DecomposeDifferentRankInsertStridedSlice, 290 DecomposeNDExtractStridedSlice>(patterns.getContext()); 291 } 292 293 /// Populate the given list with patterns that convert from Vector to LLVM. 294 void mlir::vector::populateVectorInsertExtractStridedSliceTransforms( 295 RewritePatternSet &patterns) { 296 populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns); 297 patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle, 298 Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext()); 299 } 300