//===- VectorUnroll.cpp - patterns to do vector unrolling -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to do vector unrolling and vector distribution.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <numeric>
#include <optional>

#define DEBUG_TYPE "vector-unroll"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
using namespace mlir::vector;

/// Compute the indices of the slice at `elementOffsets` for a transfer op.
static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
                                               ArrayRef<Value> indices,
                                               AffineMap permutationMap,
                                               Location loc,
                                               OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  // Compute 'slicedIndices' by adding 'elementOffsets[i]' to 'indices[i]'.
  SmallVector<Value> slicedIndices(indices);
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, ctx) +
                getAffineConstantExpr(elementOffsets[dim.index()], ctx);
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
  }
  return slicedIndices;
}
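// Illustrative example (editor's sketch, not from the original source): for a
// permutation map (d0, d1) -> (d1), base indices [%i, %j], and
// elementOffsets = [8], only the index feeding the transfer's vector
// dimension is shifted, yielding
// [%i, affine.apply affine_map<(d0) -> (d0 + 8)>(%j)]. Broadcast (constant-0)
// results in the permutation map leave the corresponding indices untouched.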

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
                                              Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  return builder.create(loc, op->getName().getIdentifier(), operands,
                        resultTypes, op->getAttrs());
}

/// Return the target shape for unrolling for the given `op`. Return
/// std::nullopt if the op shouldn't be or cannot be unrolled.
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG("");
  LDBG("Get unroll shape for op " << op->getName().getStringRef());
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG("--did not pass the filter constraint -> BAIL");
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG("--not an unrollable op -> BAIL");
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG("--could not get shape of op " << *op << " -> BAIL");
    return std::nullopt;
  }
  LLVM_DEBUG(
      llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: ");
      llvm::dbgs() << "\n";);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG("--no unrolling target shape defined " << *op << " -> SKIP");
    return std::nullopt;
  }
  LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: ");
             llvm::dbgs() << "\n";);

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG("--could not compute integral shape ratio -> BAIL");
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG("--no unrolling needed -> SKIP");
    return std::nullopt;
  }
  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
  return targetShape;
}
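// Worked example (editor's sketch): for an op whose shape-for-unroll is
// [8, 8] and a native shape of [4, 4], computeShapeRatio returns [2, 2], so
// the op is unrolled into 2 x 2 = 4 slices of shape 4x4. If the native shape
// equals the op's shape, the ratio is all ones and unrolling is skipped.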

static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
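// Illustrative effect (editor's sketch, types abridged): a
// vector.transfer_read producing vector<4x4xf32>, unrolled with a native
// shape of [2, 2], becomes four transfer_reads of vector<2x2xf32>, each with
// indices offset via affine.apply, whose results are composed back into a
// vector<4x4xf32> with vector.insert_strided_slice over an arith.constant
// zero vector.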

struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case, update the destination for the next transfer
      // write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
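// Note on the tensor case (descriptive comment added for clarity): each
// unrolled vector.transfer_write on a tensor yields a new tensor value, so
// the pattern threads `resultTensor` through the loop and replaces the
// original op with the last produced tensor; for memrefs there is no result
// and the original op is simply erased.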

struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffsets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loops keep updating the accumulator.
      accCache[dstOffsets] = newOp->getResult(0);
    }
    // Assemble the accumulator back into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
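// Illustrative example (editor's sketch): for a matmul-style contraction with
// lhs vector<8x8xf32> (m, k), rhs vector<8x8xf32> (k, n), acc vector<8x8xf32>
// (m, n), and a native shape of [4, 4, 4] over (m, n, k), each of the eight
// (m, n, k) tiles extracts 4x4 slices of lhs and rhs, reuses the cached 4x4
// accumulator slice for its (m, n) offset across the two k tiles, and the
// final accumulator slices are reassembled into the 8x8 result.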

struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Iterate over all the tile offsets of the original shape, in steps of the
    // target shape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it;
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble the accumulator back into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
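// Illustrative example (editor's sketch): a vector.multi_reduction over dim 1
// of a vector<4x8xf32> with acc vector<4xf32> and a native shape of [2, 4]
// produces four tiles; each tile reduces a 2x4 slice into a vector<2xf32>,
// the cached partial accumulator for a given non-reduced offset is reused
// across the tiles along the reduced dimension, and the partial results are
// reassembled into the vector<4xf32> destination.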

struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
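// Illustrative example (editor's sketch): an arith.addf on vector<4x4xf32>
// operands unrolled with a native shape of [2, 2] becomes four arith.addf ops
// on vector<2x2xf32> slices extracted with vector.extract_strided_slice,
// whose results are inserted back into a zero-initialized vector<4x4xf32>.
// Scalar (non-vector) operands are passed through to every slice unchanged.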

struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reductions, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};
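// Illustrative example (editor's sketch): a vector.reduction <add> over
// vector<16xf32> with a native shape of [8] becomes two vector.reduction ops
// on vector<8xf32> slices whose scalar results are combined by the arith op
// produced by makeArithReduction (arith.addf for a floating-point add).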

struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
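// Illustrative example (editor's sketch): for a vector.transpose with
// permutation [1, 0] producing vector<4x8xf32>, the result tile at offsets
// [2, 4] with native shape [2, 4] is computed by extracting the source slice
// at the permuted offsets [4, 2] with the permuted shape [4, 2], transposing
// that slice, and inserting it at [2, 4] in the result.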

struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern>(
      patterns.getContext(), options, benefit);
}
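// Example usage (editor's illustrative sketch, not part of the original file;
// the surrounding pass boilerplate, `ctx`, `funcOp`, and the greedy driver
// include are assumed):
//
//   RewritePatternSet patterns(ctx);
//   vector::UnrollVectorOptions options;
//   // Unroll every unrollable vector op whose shape-for-unroll is 2-D to a
//   // fixed 4x4 native shape (a real callback would match rank per op).
//   options.setNativeShapeFn(
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         return SmallVector<int64_t>{4, 4};
//       });
//   vector::populateVectorUnrollPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));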