1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/Utils/IndexingUtils.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::vector; 20 21 // Helper that picks the proper sequence for inserting. 22 static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, 23 Value into, int64_t offset) { 24 auto vectorType = cast<VectorType>(into.getType()); 25 if (vectorType.getRank() > 1) 26 return rewriter.create<InsertOp>(loc, from, into, offset); 27 return rewriter.create<vector::InsertElementOp>( 28 loc, vectorType, from, into, 29 rewriter.create<arith::ConstantIndexOp>(loc, offset)); 30 } 31 32 // Helper that picks the proper sequence for extracting. 33 static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, 34 int64_t offset) { 35 auto vectorType = cast<VectorType>(vector.getType()); 36 if (vectorType.getRank() > 1) 37 return rewriter.create<ExtractOp>(loc, vector, offset); 38 return rewriter.create<vector::ExtractElementOp>( 39 loc, vectorType.getElementType(), vector, 40 rewriter.create<arith::ConstantIndexOp>(loc, offset)); 41 } 42 43 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 44 /// have different ranks. 45 /// 46 /// When ranks are different, InsertStridedSlice needs to extract a properly 47 /// ranked vector from the destination vector into which to insert. This pattern 48 /// only takes care of this extraction part and forwards the rest to 49 /// [ConvertSameRankInsertStridedSliceIntoShuffle]. 50 /// 51 /// For a k-D source and n-D destination vector (k < n), we emit: 52 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to 53 /// insert the k-D source. 54 /// 2. k-D -> (n-1)-D InsertStridedSlice op 55 /// 3. InsertOp that is the reverse of 1. 56 class DecomposeDifferentRankInsertStridedSlice 57 : public OpRewritePattern<InsertStridedSliceOp> { 58 public: 59 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 60 61 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 62 PatternRewriter &rewriter) const override { 63 auto srcType = op.getSourceVectorType(); 64 auto dstType = op.getDestVectorType(); 65 66 if (op.getOffsets().getValue().empty()) 67 return failure(); 68 69 auto loc = op.getLoc(); 70 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 71 assert(rankDiff >= 0); 72 if (rankDiff == 0) 73 return failure(); 74 75 int64_t rankRest = dstType.getRank() - rankDiff; 76 // Extract / insert the subvector of matching rank and InsertStridedSlice 77 // on it. 78 Value extracted = rewriter.create<ExtractOp>( 79 loc, op.getDest(), 80 getI64SubArray(op.getOffsets(), /*dropFront=*/0, 81 /*dropBack=*/rankRest)); 82 83 // A different pattern will kick in for InsertStridedSlice with matching 84 // ranks. 85 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 86 loc, op.getSource(), extracted, 87 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff), 88 getI64SubArray(op.getStrides(), /*dropFront=*/0)); 89 90 rewriter.replaceOpWithNewOp<InsertOp>( 91 op, stridedSliceInnerOp.getResult(), op.getDest(), 92 getI64SubArray(op.getOffsets(), /*dropFront=*/0, 93 /*dropBack=*/rankRest)); 94 return success(); 95 } 96 }; 97 98 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 99 /// have the same rank. For each outermost index in the slice: 100 /// begin end stride 101 /// [offset : offset+size*stride : stride] 102 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector. 103 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D 104 /// 3. the destination subvector is inserted back in the proper place 105 /// 3. InsertOp that is the reverse of 1. 106 class ConvertSameRankInsertStridedSliceIntoShuffle 107 : public OpRewritePattern<InsertStridedSliceOp> { 108 public: 109 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 110 111 void initialize() { 112 // This pattern creates recursive InsertStridedSliceOp, but the recursion is 113 // bounded as the rank is strictly decreasing. 114 setHasBoundedRewriteRecursion(); 115 } 116 117 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 118 PatternRewriter &rewriter) const override { 119 auto srcType = op.getSourceVectorType(); 120 auto dstType = op.getDestVectorType(); 121 122 if (op.getOffsets().getValue().empty()) 123 return failure(); 124 125 int64_t srcRank = srcType.getRank(); 126 int64_t dstRank = dstType.getRank(); 127 assert(dstRank >= srcRank); 128 if (dstRank != srcRank) 129 return failure(); 130 131 if (srcType == dstType) { 132 rewriter.replaceOp(op, op.getSource()); 133 return success(); 134 } 135 136 int64_t offset = 137 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); 138 int64_t size = srcType.getShape().front(); 139 int64_t stride = 140 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); 141 142 auto loc = op.getLoc(); 143 Value res = op.getDest(); 144 145 if (srcRank == 1) { 146 int nSrc = srcType.getShape().front(); 147 int nDest = dstType.getShape().front(); 148 // 1. Scale source to destType so we can shufflevector them together. 149 SmallVector<int64_t> offsets(nDest, 0); 150 for (int64_t i = 0; i < nSrc; ++i) 151 offsets[i] = i; 152 Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(), 153 op.getSource(), offsets); 154 155 // 2. Create a mask where we take the value from scaledSource of dest 156 // depending on the offset. 157 offsets.clear(); 158 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) { 159 if (i < offset || i >= e || (i - offset) % stride != 0) 160 offsets.push_back(nDest + i); 161 else 162 offsets.push_back((i - offset) / stride); 163 } 164 165 // 3. Replace with a ShuffleOp. 166 rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(), 167 offsets); 168 169 return success(); 170 } 171 172 // For each slice of the source vector along the most major dimension. 173 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 174 off += stride, ++idx) { 175 // 1. extract the proper subvector (or element) from source 176 Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx); 177 if (isa<VectorType>(extractedSource.getType())) { 178 // 2. If we have a vector, extract the proper subvector from destination 179 // Otherwise we are at the element level and no need to recurse. 180 Value extractedDest = extractOne(rewriter, loc, op.getDest(), off); 181 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 182 // smaller rank. 183 extractedSource = rewriter.create<InsertStridedSliceOp>( 184 loc, extractedSource, extractedDest, 185 getI64SubArray(op.getOffsets(), /* dropFront=*/1), 186 getI64SubArray(op.getStrides(), /* dropFront=*/1)); 187 } 188 // 4. Insert the extractedSource into the res vector. 189 res = insertOne(rewriter, loc, extractedSource, res, off); 190 } 191 192 rewriter.replaceOp(op, res); 193 return success(); 194 } 195 }; 196 197 /// RewritePattern for ExtractStridedSliceOp where source and destination 198 /// vectors are 1-D. For such cases, we can lower it to a ShuffleOp. 199 class Convert1DExtractStridedSliceIntoShuffle 200 : public OpRewritePattern<ExtractStridedSliceOp> { 201 public: 202 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 203 204 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 205 PatternRewriter &rewriter) const override { 206 auto dstType = op.getType(); 207 208 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); 209 210 int64_t offset = 211 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); 212 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt(); 213 int64_t stride = 214 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); 215 216 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat()); 217 218 // Single offset can be more efficiently shuffled. 219 if (op.getOffsets().getValue().size() != 1) 220 return failure(); 221 222 SmallVector<int64_t, 4> offsets; 223 offsets.reserve(size); 224 for (int64_t off = offset, e = offset + size * stride; off < e; 225 off += stride) 226 offsets.push_back(off); 227 rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(), 228 op.getVector(), offsets); 229 return success(); 230 } 231 }; 232 233 /// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops 234 /// to extract each element from the source, and then a chain of Insert ops 235 /// to insert to the target vector. 236 class Convert1DExtractStridedSliceIntoExtractInsertChain final 237 : public OpRewritePattern<ExtractStridedSliceOp> { 238 public: 239 Convert1DExtractStridedSliceIntoExtractInsertChain( 240 MLIRContext *context, 241 std::function<bool(ExtractStridedSliceOp)> controlFn, 242 PatternBenefit benefit) 243 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} 244 245 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 246 PatternRewriter &rewriter) const override { 247 if (controlFn && !controlFn(op)) 248 return failure(); 249 250 // Only handle 1-D cases. 251 if (op.getOffsets().getValue().size() != 1) 252 return failure(); 253 254 int64_t offset = 255 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); 256 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt(); 257 int64_t stride = 258 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); 259 260 Location loc = op.getLoc(); 261 SmallVector<Value> elements; 262 elements.reserve(size); 263 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride) 264 elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i)); 265 266 Value result = rewriter.create<arith::ConstantOp>( 267 loc, rewriter.getZeroAttr(op.getType())); 268 for (int64_t i = 0; i < size; ++i) 269 result = rewriter.create<InsertOp>(loc, elements[i], result, i); 270 271 rewriter.replaceOp(op, result); 272 return success(); 273 } 274 275 private: 276 std::function<bool(ExtractStridedSliceOp)> controlFn; 277 }; 278 279 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D. 280 /// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower 281 /// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case. 282 class DecomposeNDExtractStridedSlice 283 : public OpRewritePattern<ExtractStridedSliceOp> { 284 public: 285 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 286 287 void initialize() { 288 // This pattern creates recursive ExtractStridedSliceOp, but the recursion 289 // is bounded as the rank is strictly decreasing. 290 setHasBoundedRewriteRecursion(); 291 } 292 293 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 294 PatternRewriter &rewriter) const override { 295 auto dstType = op.getType(); 296 297 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); 298 299 int64_t offset = 300 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); 301 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt(); 302 int64_t stride = 303 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); 304 305 auto loc = op.getLoc(); 306 auto elemType = dstType.getElementType(); 307 assert(elemType.isSignlessIntOrIndexOrFloat()); 308 309 // Single offset can be more efficiently shuffled. It's handled in 310 // Convert1DExtractStridedSliceIntoShuffle. 311 if (op.getOffsets().getValue().size() == 1) 312 return failure(); 313 314 // Extract/insert on a lower ranked extract strided slice op. 315 Value zero = rewriter.create<arith::ConstantOp>( 316 loc, elemType, rewriter.getZeroAttr(elemType)); 317 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 318 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 319 off += stride, ++idx) { 320 Value one = extractOne(rewriter, loc, op.getVector(), off); 321 Value extracted = rewriter.create<ExtractStridedSliceOp>( 322 loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1), 323 getI64SubArray(op.getSizes(), /* dropFront=*/1), 324 getI64SubArray(op.getStrides(), /* dropFront=*/1)); 325 res = insertOne(rewriter, loc, extracted, res, idx); 326 } 327 rewriter.replaceOp(op, res); 328 return success(); 329 } 330 }; 331 332 /// Pattern to rewrite simple cases of N-D extract_strided_slice, where the 333 /// slice is contiguous, into extract and shape_cast. 334 class ContiguousExtractStridedSliceToExtract final 335 : public OpRewritePattern<ExtractStridedSliceOp> { 336 public: 337 using OpRewritePattern::OpRewritePattern; 338 339 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 340 PatternRewriter &rewriter) const override { 341 if (op.hasNonUnitStrides()) { 342 return failure(); 343 } 344 Value source = op.getOperand(); 345 auto sourceType = cast<VectorType>(source.getType()); 346 if (sourceType.isScalable()) { 347 return failure(); 348 } 349 350 // Compute the number of offsets to pass to ExtractOp::build. That is the 351 // difference between the source rank and the desired slice rank. We walk 352 // the dimensions from innermost out, and stop when the next slice dimension 353 // is not full-size. 354 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes()); 355 int numOffsets; 356 for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) { 357 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) { 358 break; 359 } 360 } 361 362 // If not even the inner-most dimension is full-size, this op can't be 363 // rewritten as an ExtractOp. 364 if (numOffsets == sourceType.getRank()) { 365 return failure(); 366 } 367 368 // Avoid generating slices that have unit outer dimensions. The shape_cast 369 // op that we create below would take bad generic fallback patterns 370 // (ShapeCastOpRewritePattern). 371 while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) { 372 ++numOffsets; 373 } 374 375 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets()); 376 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets); 377 Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source, 378 extractOffsets); 379 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract); 380 return success(); 381 } 382 }; 383 384 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( 385 RewritePatternSet &patterns, PatternBenefit benefit) { 386 patterns.add<DecomposeDifferentRankInsertStridedSlice, 387 DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit); 388 } 389 390 void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns( 391 RewritePatternSet &patterns, PatternBenefit benefit) { 392 patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(), 393 benefit); 394 } 395 396 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns( 397 RewritePatternSet &patterns, 398 std::function<bool(ExtractStridedSliceOp)> controlFn, 399 PatternBenefit benefit) { 400 patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>( 401 patterns.getContext(), std::move(controlFn), benefit); 402 } 403 404 /// Populate the given list with patterns that convert from Vector to LLVM. 405 void vector::populateVectorInsertExtractStridedSliceTransforms( 406 RewritePatternSet &patterns, PatternBenefit benefit) { 407 populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns, 408 benefit); 409 patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle, 410 Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(), 411 benefit); 412 } 413