1 //===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arith/IR/Arith.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/Dialect/Utils/IndexingUtils.h" 12 #include "mlir/Dialect/Vector/IR/VectorOps.h" 13 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 14 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::vector; 20 21 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 22 /// have different ranks. 23 /// 24 /// When ranks are different, InsertStridedSlice needs to extract a properly 25 /// ranked vector from the destination vector into which to insert. This pattern 26 /// only takes care of this extraction part and forwards the rest to 27 /// [ConvertSameRankInsertStridedSliceIntoShuffle]. 28 /// 29 /// For a k-D source and n-D destination vector (k < n), we emit: 30 /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to 31 /// insert the k-D source. 32 /// 2. k-D -> (n-1)-D InsertStridedSlice op 33 /// 3. InsertOp that is the reverse of 1. 34 class DecomposeDifferentRankInsertStridedSlice 35 : public OpRewritePattern<InsertStridedSliceOp> { 36 public: 37 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 38 39 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 40 PatternRewriter &rewriter) const override { 41 auto srcType = op.getSourceVectorType(); 42 auto dstType = op.getDestVectorType(); 43 44 if (op.getOffsets().getValue().empty()) 45 return failure(); 46 47 auto loc = op.getLoc(); 48 int64_t rankDiff = dstType.getRank() - srcType.getRank(); 49 assert(rankDiff >= 0); 50 if (rankDiff == 0) 51 return failure(); 52 53 int64_t rankRest = dstType.getRank() - rankDiff; 54 // Extract / insert the subvector of matching rank and InsertStridedSlice 55 // on it. 56 Value extracted = rewriter.create<ExtractOp>( 57 loc, op.getDest(), 58 getI64SubArray(op.getOffsets(), /*dropFront=*/0, 59 /*dropBack=*/rankRest)); 60 61 // A different pattern will kick in for InsertStridedSlice with matching 62 // ranks. 63 auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>( 64 loc, op.getSource(), extracted, 65 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff), 66 getI64SubArray(op.getStrides(), /*dropFront=*/0)); 67 68 rewriter.replaceOpWithNewOp<InsertOp>( 69 op, stridedSliceInnerOp.getResult(), op.getDest(), 70 getI64SubArray(op.getOffsets(), /*dropFront=*/0, 71 /*dropBack=*/rankRest)); 72 return success(); 73 } 74 }; 75 76 /// RewritePattern for InsertStridedSliceOp where source and destination vectors 77 /// have the same rank. For each outermost index in the slice: 78 /// begin end stride 79 /// [offset : offset+size*stride : stride] 80 /// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector. 81 /// 2. InsertStridedSlice (k-1)-D into (n-1)-D 82 /// 3. the destination subvector is inserted back in the proper place 83 /// 3. InsertOp that is the reverse of 1. 84 class ConvertSameRankInsertStridedSliceIntoShuffle 85 : public OpRewritePattern<InsertStridedSliceOp> { 86 public: 87 using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; 88 89 void initialize() { 90 // This pattern creates recursive InsertStridedSliceOp, but the recursion is 91 // bounded as the rank is strictly decreasing. 92 setHasBoundedRewriteRecursion(); 93 } 94 95 LogicalResult matchAndRewrite(InsertStridedSliceOp op, 96 PatternRewriter &rewriter) const override { 97 auto srcType = op.getSourceVectorType(); 98 auto dstType = op.getDestVectorType(); 99 100 if (op.getOffsets().getValue().empty()) 101 return failure(); 102 103 int64_t srcRank = srcType.getRank(); 104 int64_t dstRank = dstType.getRank(); 105 assert(dstRank >= srcRank); 106 if (dstRank != srcRank) 107 return failure(); 108 109 if (srcType == dstType) { 110 rewriter.replaceOp(op, op.getSource()); 111 return success(); 112 } 113 114 int64_t offset = 115 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); 116 int64_t size = srcType.getShape().front(); 117 int64_t stride = 118 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); 119 120 auto loc = op.getLoc(); 121 Value res = op.getDest(); 122 123 if (srcRank == 1) { 124 int nSrc = srcType.getShape().front(); 125 int nDest = dstType.getShape().front(); 126 // 1. Scale source to destType so we can shufflevector them together. 127 SmallVector<int64_t> offsets(nDest, 0); 128 for (int64_t i = 0; i < nSrc; ++i) 129 offsets[i] = i; 130 Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(), 131 op.getSource(), offsets); 132 133 // 2. Create a mask where we take the value from scaledSource of dest 134 // depending on the offset. 135 offsets.clear(); 136 for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) { 137 if (i < offset || i >= e || (i - offset) % stride != 0) 138 offsets.push_back(nDest + i); 139 else 140 offsets.push_back((i - offset) / stride); 141 } 142 143 // 3. Replace with a ShuffleOp. 144 rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(), 145 offsets); 146 147 return success(); 148 } 149 150 // For each slice of the source vector along the most major dimension. 151 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 152 off += stride, ++idx) { 153 // 1. extract the proper subvector (or element) from source 154 Value extractedSource = 155 rewriter.create<ExtractOp>(loc, op.getSource(), idx); 156 if (isa<VectorType>(extractedSource.getType())) { 157 // 2. If we have a vector, extract the proper subvector from destination 158 // Otherwise we are at the element level and no need to recurse. 159 Value extractedDest = 160 rewriter.create<ExtractOp>(loc, op.getDest(), off); 161 // 3. Reduce the problem to lowering a new InsertStridedSlice op with 162 // smaller rank. 163 extractedSource = rewriter.create<InsertStridedSliceOp>( 164 loc, extractedSource, extractedDest, 165 getI64SubArray(op.getOffsets(), /* dropFront=*/1), 166 getI64SubArray(op.getStrides(), /* dropFront=*/1)); 167 } 168 // 4. Insert the extractedSource into the res vector. 169 res = rewriter.create<InsertOp>(loc, extractedSource, res, off); 170 } 171 172 rewriter.replaceOp(op, res); 173 return success(); 174 } 175 }; 176 177 /// RewritePattern for ExtractStridedSliceOp where source and destination 178 /// vectors are 1-D. For such cases, we can lower it to a ShuffleOp. 179 class Convert1DExtractStridedSliceIntoShuffle 180 : public OpRewritePattern<ExtractStridedSliceOp> { 181 public: 182 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 183 184 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 185 PatternRewriter &rewriter) const override { 186 auto dstType = op.getType(); 187 188 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); 189 190 int64_t offset = 191 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); 192 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt(); 193 int64_t stride = 194 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); 195 196 assert(dstType.getElementType().isSignlessIntOrIndexOrFloat()); 197 198 // Single offset can be more efficiently shuffled. 199 if (op.getOffsets().getValue().size() != 1) 200 return failure(); 201 202 SmallVector<int64_t, 4> offsets; 203 offsets.reserve(size); 204 for (int64_t off = offset, e = offset + size * stride; off < e; 205 off += stride) 206 offsets.push_back(off); 207 rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(), 208 op.getVector(), offsets); 209 return success(); 210 } 211 }; 212 213 /// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops 214 /// to extract each element from the source, and then a chain of Insert ops 215 /// to insert to the target vector. 216 class Convert1DExtractStridedSliceIntoExtractInsertChain final 217 : public OpRewritePattern<ExtractStridedSliceOp> { 218 public: 219 Convert1DExtractStridedSliceIntoExtractInsertChain( 220 MLIRContext *context, 221 std::function<bool(ExtractStridedSliceOp)> controlFn, 222 PatternBenefit benefit) 223 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} 224 225 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 226 PatternRewriter &rewriter) const override { 227 if (controlFn && !controlFn(op)) 228 return failure(); 229 230 // Only handle 1-D cases. 231 if (op.getOffsets().getValue().size() != 1) 232 return failure(); 233 234 int64_t offset = 235 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); 236 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt(); 237 int64_t stride = 238 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); 239 240 Location loc = op.getLoc(); 241 SmallVector<Value> elements; 242 elements.reserve(size); 243 for (int64_t i = offset, e = offset + size * stride; i < e; i += stride) 244 elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i)); 245 246 Value result = rewriter.create<arith::ConstantOp>( 247 loc, rewriter.getZeroAttr(op.getType())); 248 for (int64_t i = 0; i < size; ++i) 249 result = rewriter.create<InsertOp>(loc, elements[i], result, i); 250 251 rewriter.replaceOp(op, result); 252 return success(); 253 } 254 255 private: 256 std::function<bool(ExtractStridedSliceOp)> controlFn; 257 }; 258 259 /// RewritePattern for ExtractStridedSliceOp where the source vector is n-D. 260 /// For such cases, we can rewrite it to ExtractOp + lower rank 261 /// ExtractStridedSliceOp + InsertOp for the n-D case. 262 class DecomposeNDExtractStridedSlice 263 : public OpRewritePattern<ExtractStridedSliceOp> { 264 public: 265 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; 266 267 void initialize() { 268 // This pattern creates recursive ExtractStridedSliceOp, but the recursion 269 // is bounded as the rank is strictly decreasing. 270 setHasBoundedRewriteRecursion(); 271 } 272 273 LogicalResult matchAndRewrite(ExtractStridedSliceOp op, 274 PatternRewriter &rewriter) const override { 275 auto dstType = op.getType(); 276 277 assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); 278 279 int64_t offset = 280 cast<IntegerAttr>(op.getOffsets().getValue().front()).getInt(); 281 int64_t size = cast<IntegerAttr>(op.getSizes().getValue().front()).getInt(); 282 int64_t stride = 283 cast<IntegerAttr>(op.getStrides().getValue().front()).getInt(); 284 285 auto loc = op.getLoc(); 286 auto elemType = dstType.getElementType(); 287 assert(elemType.isSignlessIntOrIndexOrFloat()); 288 289 // Single offset can be more efficiently shuffled. It's handled in 290 // Convert1DExtractStridedSliceIntoShuffle. 291 if (op.getOffsets().getValue().size() == 1) 292 return failure(); 293 294 // Extract/insert on a lower ranked extract strided slice op. 295 Value zero = rewriter.create<arith::ConstantOp>( 296 loc, elemType, rewriter.getZeroAttr(elemType)); 297 Value res = rewriter.create<SplatOp>(loc, dstType, zero); 298 for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; 299 off += stride, ++idx) { 300 Value one = rewriter.create<ExtractOp>(loc, op.getVector(), off); 301 Value extracted = rewriter.create<ExtractStridedSliceOp>( 302 loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1), 303 getI64SubArray(op.getSizes(), /* dropFront=*/1), 304 getI64SubArray(op.getStrides(), /* dropFront=*/1)); 305 res = rewriter.create<InsertOp>(loc, extracted, res, idx); 306 } 307 rewriter.replaceOp(op, res); 308 return success(); 309 } 310 }; 311 312 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( 313 RewritePatternSet &patterns, PatternBenefit benefit) { 314 patterns.add<DecomposeDifferentRankInsertStridedSlice, 315 DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit); 316 } 317 318 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns( 319 RewritePatternSet &patterns, 320 std::function<bool(ExtractStridedSliceOp)> controlFn, 321 PatternBenefit benefit) { 322 patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>( 323 patterns.getContext(), std::move(controlFn), benefit); 324 } 325 326 /// Populate the given list with patterns that convert from Vector to LLVM. 327 void vector::populateVectorInsertExtractStridedSliceTransforms( 328 RewritePatternSet &patterns, PatternBenefit benefit) { 329 populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns, 330 benefit); 331 patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle, 332 Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(), 333 benefit); 334 } 335