xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (revision 5262865aac683b72f3e66de7a122e0c455ab6b9b)
154db8cc7SThomas Raoux //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===//
254db8cc7SThomas Raoux //
354db8cc7SThomas Raoux // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
454db8cc7SThomas Raoux // See https://llvm.org/LICENSE.txt for license information.
554db8cc7SThomas Raoux // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
654db8cc7SThomas Raoux //
754db8cc7SThomas Raoux //===----------------------------------------------------------------------===//
854db8cc7SThomas Raoux //
954db8cc7SThomas Raoux // This file implements patterns to do vector unrolling and vector distribution.
1054db8cc7SThomas Raoux //
1154db8cc7SThomas Raoux //===----------------------------------------------------------------------===//
1254db8cc7SThomas Raoux 
1354db8cc7SThomas Raoux #include "mlir/Dialect/Affine/IR/AffineOps.h"
1454db8cc7SThomas Raoux #include "mlir/Dialect/Utils/IndexingUtils.h"
1554db8cc7SThomas Raoux #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1654db8cc7SThomas Raoux #include "mlir/IR/ImplicitLocOpBuilder.h"
1754db8cc7SThomas Raoux #include "mlir/Interfaces/VectorInterfaces.h"
1854db8cc7SThomas Raoux #include "llvm/ADT/MapVector.h"
1954db8cc7SThomas Raoux #include "llvm/ADT/STLExtras.h"
20e35ff260SNicolas Vasilache #include "llvm/Support/Debug.h"
2154db8cc7SThomas Raoux #include <numeric>
22a1fe1f5fSKazu Hirata #include <optional>
2354db8cc7SThomas Raoux 
24e35ff260SNicolas Vasilache #define DEBUG_TYPE "vector-unroll"
25e35ff260SNicolas Vasilache #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26e35ff260SNicolas Vasilache #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
2754db8cc7SThomas Raoux 
2854db8cc7SThomas Raoux using namespace mlir;
2954db8cc7SThomas Raoux using namespace mlir::vector;
3054db8cc7SThomas Raoux 
3154db8cc7SThomas Raoux /// Compute the indices of the slice `index` for a tranfer op.
3254db8cc7SThomas Raoux static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
3354db8cc7SThomas Raoux                                                ArrayRef<Value> indices,
3454db8cc7SThomas Raoux                                                AffineMap permutationMap,
3554db8cc7SThomas Raoux                                                Location loc,
3654db8cc7SThomas Raoux                                                OpBuilder &builder) {
3754db8cc7SThomas Raoux   MLIRContext *ctx = builder.getContext();
3854db8cc7SThomas Raoux   auto isBroadcast = [](AffineExpr expr) {
391609f1c2Slong.chen     if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
4054db8cc7SThomas Raoux       return constExpr.getValue() == 0;
4154db8cc7SThomas Raoux     return false;
4254db8cc7SThomas Raoux   };
4354db8cc7SThomas Raoux   // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
44*5262865aSKazu Hirata   SmallVector<Value> slicedIndices(indices);
4554db8cc7SThomas Raoux   for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
4654db8cc7SThomas Raoux     if (isBroadcast(dim.value()))
4754db8cc7SThomas Raoux       continue;
481609f1c2Slong.chen     unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
4954db8cc7SThomas Raoux     auto expr = getAffineDimExpr(0, builder.getContext()) +
5054db8cc7SThomas Raoux                 getAffineConstantExpr(elementOffsets[dim.index()], ctx);
5154db8cc7SThomas Raoux     auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
524c48f016SMatthias Springer     slicedIndices[pos] =
534c48f016SMatthias Springer         builder.create<affine::AffineApplyOp>(loc, map, indices[pos]);
5454db8cc7SThomas Raoux   }
5554db8cc7SThomas Raoux   return slicedIndices;
5654db8cc7SThomas Raoux }
5754db8cc7SThomas Raoux 
5854db8cc7SThomas Raoux // Clones `op` into a new operations that takes `operands` and returns
5954db8cc7SThomas Raoux // `resultTypes`.
6054db8cc7SThomas Raoux static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
6154db8cc7SThomas Raoux                                               Operation *op,
6254db8cc7SThomas Raoux                                               ArrayRef<Value> operands,
6354db8cc7SThomas Raoux                                               ArrayRef<Type> resultTypes) {
6454db8cc7SThomas Raoux   return builder.create(loc, op->getName().getIdentifier(), operands,
6554db8cc7SThomas Raoux                         resultTypes, op->getAttrs());
6654db8cc7SThomas Raoux }
6754db8cc7SThomas Raoux 
6870c73d1bSKazu Hirata /// Return the target shape for unrolling for the given `op`. Return
6970c73d1bSKazu Hirata /// std::nullopt if the op shouldn't be or cannot be unrolled.
700a81ace0SKazu Hirata static std::optional<SmallVector<int64_t>>
7154db8cc7SThomas Raoux getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
72e35ff260SNicolas Vasilache   LDBG("");
73e35ff260SNicolas Vasilache   LDBG("Get unroll shape for op " << op->getName().getStringRef());
74e35ff260SNicolas Vasilache   if (options.filterConstraint && failed(options.filterConstraint(op))) {
75e35ff260SNicolas Vasilache     LDBG("--no filter constraint -> BAIL");
761a36588eSKazu Hirata     return std::nullopt;
77e35ff260SNicolas Vasilache   }
7854db8cc7SThomas Raoux   assert(options.nativeShape &&
7954db8cc7SThomas Raoux          "vector unrolling expects the native shape or native"
8054db8cc7SThomas Raoux          "shape call back function to be set");
8154db8cc7SThomas Raoux   auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
82e35ff260SNicolas Vasilache   if (!unrollableVectorOp) {
83e35ff260SNicolas Vasilache     LDBG("--not an unrollable op -> BAIL");
841a36588eSKazu Hirata     return std::nullopt;
85e35ff260SNicolas Vasilache   }
8654db8cc7SThomas Raoux   auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
87e35ff260SNicolas Vasilache   if (!maybeUnrollShape) {
88e35ff260SNicolas Vasilache     LDBG("--could not get shape of op " << *op << " -> BAIL");
891a36588eSKazu Hirata     return std::nullopt;
90e35ff260SNicolas Vasilache   }
91e35ff260SNicolas Vasilache   LLVM_DEBUG(
92e35ff260SNicolas Vasilache       llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: ");
93e35ff260SNicolas Vasilache       llvm::dbgs() << "\n";);
94e35ff260SNicolas Vasilache 
950a81ace0SKazu Hirata   std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
96e35ff260SNicolas Vasilache   if (!targetShape) {
97e35ff260SNicolas Vasilache     LDBG("--no unrolling target shape defined " << *op << "-> SKIP");
981a36588eSKazu Hirata     return std::nullopt;
99e35ff260SNicolas Vasilache   }
100e35ff260SNicolas Vasilache   LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: ");
101e35ff260SNicolas Vasilache              llvm::dbgs() << "\n";);
102e35ff260SNicolas Vasilache 
1037a69a9d7SNicolas Vasilache   auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
104e35ff260SNicolas Vasilache   if (!maybeShapeRatio) {
105e35ff260SNicolas Vasilache     LDBG("--could not compute integral shape ratio -> BAIL");
1061a36588eSKazu Hirata     return std::nullopt;
107e35ff260SNicolas Vasilache   }
108e35ff260SNicolas Vasilache   if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
109e35ff260SNicolas Vasilache     LDBG("--no unrolling needed -> SKIP");
110e35ff260SNicolas Vasilache     return std::nullopt;
111e35ff260SNicolas Vasilache   }
112e35ff260SNicolas Vasilache   LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
11354db8cc7SThomas Raoux   return targetShape;
11454db8cc7SThomas Raoux }
11554db8cc7SThomas Raoux 
11654db8cc7SThomas Raoux static SmallVector<int64_t>
11754db8cc7SThomas Raoux getUnrollOrder(unsigned numLoops, Operation *op,
11854db8cc7SThomas Raoux                const vector::UnrollVectorOptions &options) {
11954db8cc7SThomas Raoux   SmallVector<int64_t> loopOrder =
12054db8cc7SThomas Raoux       llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
12154db8cc7SThomas Raoux   if (options.traversalOrderCallback != nullptr) {
1220a81ace0SKazu Hirata     std::optional<SmallVector<int64_t>> order =
1230a81ace0SKazu Hirata         options.traversalOrderCallback(op);
12454db8cc7SThomas Raoux     if (order) {
12554db8cc7SThomas Raoux       loopOrder = std::move(*order);
12654db8cc7SThomas Raoux     }
12754db8cc7SThomas Raoux   }
12854db8cc7SThomas Raoux   return loopOrder;
12954db8cc7SThomas Raoux }
13054db8cc7SThomas Raoux 
13154db8cc7SThomas Raoux namespace {
13254db8cc7SThomas Raoux 
13354db8cc7SThomas Raoux struct UnrollTransferReadPattern
13454db8cc7SThomas Raoux     : public OpRewritePattern<vector::TransferReadOp> {
13554db8cc7SThomas Raoux   UnrollTransferReadPattern(MLIRContext *context,
13654db8cc7SThomas Raoux                             const vector::UnrollVectorOptions &options,
13754db8cc7SThomas Raoux                             PatternBenefit benefit = 1)
13854db8cc7SThomas Raoux       : OpRewritePattern<vector::TransferReadOp>(context, benefit),
13954db8cc7SThomas Raoux         options(options) {}
14054db8cc7SThomas Raoux 
14154db8cc7SThomas Raoux   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
14254db8cc7SThomas Raoux                                 PatternRewriter &rewriter) const override {
14354db8cc7SThomas Raoux     // TODO: support 0-d corner case.
14454db8cc7SThomas Raoux     if (readOp.getTransferRank() == 0)
14554db8cc7SThomas Raoux       return failure();
14654db8cc7SThomas Raoux     if (readOp.getMask())
14754db8cc7SThomas Raoux       return failure();
14854db8cc7SThomas Raoux     auto targetShape = getTargetShape(options, readOp);
14954db8cc7SThomas Raoux     if (!targetShape)
15054db8cc7SThomas Raoux       return failure();
15154db8cc7SThomas Raoux     auto sourceVectorType = readOp.getVectorType();
1527a69a9d7SNicolas Vasilache     SmallVector<int64_t> strides(targetShape->size(), 1);
15354db8cc7SThomas Raoux     Location loc = readOp.getLoc();
15454db8cc7SThomas Raoux     ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
15554db8cc7SThomas Raoux 
15654db8cc7SThomas Raoux     // Prepare the result vector;
15754db8cc7SThomas Raoux     Value result = rewriter.create<arith::ConstantOp>(
15854db8cc7SThomas Raoux         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
15954db8cc7SThomas Raoux     auto targetType =
16054db8cc7SThomas Raoux         VectorType::get(*targetShape, sourceVectorType.getElementType());
1617a69a9d7SNicolas Vasilache     SmallVector<Value> originalIndices(readOp.getIndices().begin(),
16254db8cc7SThomas Raoux                                        readOp.getIndices().end());
16354db8cc7SThomas Raoux     SmallVector<int64_t> loopOrder =
16454db8cc7SThomas Raoux         getUnrollOrder(originalSize.size(), readOp, options);
165831041beSChristopher Bate     for (SmallVector<int64_t> elementOffsets :
166831041beSChristopher Bate          StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
1677a69a9d7SNicolas Vasilache       SmallVector<Value> indices =
16854db8cc7SThomas Raoux           sliceTransferIndices(elementOffsets, originalIndices,
16954db8cc7SThomas Raoux                                readOp.getPermutationMap(), loc, rewriter);
17054db8cc7SThomas Raoux       auto slicedRead = rewriter.create<vector::TransferReadOp>(
17154db8cc7SThomas Raoux           loc, targetType, readOp.getSource(), indices,
17254db8cc7SThomas Raoux           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
17354db8cc7SThomas Raoux           readOp.getInBoundsAttr());
17454db8cc7SThomas Raoux 
17554db8cc7SThomas Raoux       result = rewriter.create<vector::InsertStridedSliceOp>(
17654db8cc7SThomas Raoux           loc, slicedRead, result, elementOffsets, strides);
17754db8cc7SThomas Raoux     }
17854db8cc7SThomas Raoux     rewriter.replaceOp(readOp, result);
17954db8cc7SThomas Raoux     return success();
18054db8cc7SThomas Raoux   }
18154db8cc7SThomas Raoux 
18254db8cc7SThomas Raoux private:
18354db8cc7SThomas Raoux   vector::UnrollVectorOptions options;
18454db8cc7SThomas Raoux };
18554db8cc7SThomas Raoux 
18654db8cc7SThomas Raoux struct UnrollTransferWritePattern
18754db8cc7SThomas Raoux     : public OpRewritePattern<vector::TransferWriteOp> {
18854db8cc7SThomas Raoux   UnrollTransferWritePattern(MLIRContext *context,
18954db8cc7SThomas Raoux                              const vector::UnrollVectorOptions &options,
19054db8cc7SThomas Raoux                              PatternBenefit benefit = 1)
19154db8cc7SThomas Raoux       : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
19254db8cc7SThomas Raoux         options(options) {}
19354db8cc7SThomas Raoux 
19454db8cc7SThomas Raoux   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
19554db8cc7SThomas Raoux                                 PatternRewriter &rewriter) const override {
19654db8cc7SThomas Raoux     // TODO: support 0-d corner case.
19754db8cc7SThomas Raoux     if (writeOp.getTransferRank() == 0)
19854db8cc7SThomas Raoux       return failure();
19954db8cc7SThomas Raoux 
20054db8cc7SThomas Raoux     if (writeOp.getMask())
20154db8cc7SThomas Raoux       return failure();
20254db8cc7SThomas Raoux     auto targetShape = getTargetShape(options, writeOp);
20354db8cc7SThomas Raoux     if (!targetShape)
20454db8cc7SThomas Raoux       return failure();
20554db8cc7SThomas Raoux     auto sourceVectorType = writeOp.getVectorType();
2067a69a9d7SNicolas Vasilache     SmallVector<int64_t> strides(targetShape->size(), 1);
20754db8cc7SThomas Raoux     Location loc = writeOp.getLoc();
20854db8cc7SThomas Raoux     ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
2097a69a9d7SNicolas Vasilache     SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
21054db8cc7SThomas Raoux                                        writeOp.getIndices().end());
21154db8cc7SThomas Raoux     SmallVector<int64_t> loopOrder =
21254db8cc7SThomas Raoux         getUnrollOrder(originalSize.size(), writeOp, options);
21354db8cc7SThomas Raoux     Value resultTensor;
214831041beSChristopher Bate     for (SmallVector<int64_t> elementOffsets :
215831041beSChristopher Bate          StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
21654db8cc7SThomas Raoux       Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
21754db8cc7SThomas Raoux           loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
2187a69a9d7SNicolas Vasilache       SmallVector<Value> indices =
21954db8cc7SThomas Raoux           sliceTransferIndices(elementOffsets, originalIndices,
22054db8cc7SThomas Raoux                                writeOp.getPermutationMap(), loc, rewriter);
22154db8cc7SThomas Raoux       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
22254db8cc7SThomas Raoux           loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
22354db8cc7SThomas Raoux           indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
22454db8cc7SThomas Raoux       // For the tensor case update the destination for the next transfer write.
22554db8cc7SThomas Raoux       if (!slicedWrite->getResults().empty())
22654db8cc7SThomas Raoux         resultTensor = slicedWrite->getResult(0);
22754db8cc7SThomas Raoux     }
22854db8cc7SThomas Raoux     if (resultTensor)
22954db8cc7SThomas Raoux       rewriter.replaceOp(writeOp, resultTensor);
23054db8cc7SThomas Raoux     else
23154db8cc7SThomas Raoux       rewriter.eraseOp(writeOp);
23254db8cc7SThomas Raoux     return success();
23354db8cc7SThomas Raoux   }
23454db8cc7SThomas Raoux 
23554db8cc7SThomas Raoux private:
23654db8cc7SThomas Raoux   vector::UnrollVectorOptions options;
23754db8cc7SThomas Raoux };
23854db8cc7SThomas Raoux 
23954db8cc7SThomas Raoux struct OffsetMapInfo {
24054db8cc7SThomas Raoux   static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }
24154db8cc7SThomas Raoux 
24254db8cc7SThomas Raoux   static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }
24354db8cc7SThomas Raoux 
24454db8cc7SThomas Raoux   static unsigned getHashValue(const SmallVector<int64_t> &v) {
24554db8cc7SThomas Raoux     return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
24654db8cc7SThomas Raoux   }
24754db8cc7SThomas Raoux 
24854db8cc7SThomas Raoux   static bool isEqual(const SmallVector<int64_t> &lhs,
24954db8cc7SThomas Raoux                       const SmallVector<int64_t> &rhs) {
25054db8cc7SThomas Raoux     return lhs == rhs;
25154db8cc7SThomas Raoux   }
25254db8cc7SThomas Raoux };
25354db8cc7SThomas Raoux 
25454db8cc7SThomas Raoux struct UnrollContractionPattern
25554db8cc7SThomas Raoux     : public OpRewritePattern<vector::ContractionOp> {
25654db8cc7SThomas Raoux   UnrollContractionPattern(MLIRContext *context,
25754db8cc7SThomas Raoux                            const vector::UnrollVectorOptions &options,
25854db8cc7SThomas Raoux                            PatternBenefit benefit = 1)
25954db8cc7SThomas Raoux       : OpRewritePattern<vector::ContractionOp>(context, benefit),
26054db8cc7SThomas Raoux         options(options) {}
26154db8cc7SThomas Raoux 
26254db8cc7SThomas Raoux   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
26354db8cc7SThomas Raoux                                 PatternRewriter &rewriter) const override {
26454db8cc7SThomas Raoux     auto targetShape = getTargetShape(options, contractOp);
26554db8cc7SThomas Raoux     if (!targetShape)
26654db8cc7SThomas Raoux       return failure();
2675550c821STres Popp     auto dstVecType = cast<VectorType>(contractOp.getResultType());
2687a69a9d7SNicolas Vasilache     SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
26954db8cc7SThomas Raoux 
27054db8cc7SThomas Raoux     Location loc = contractOp.getLoc();
27154db8cc7SThomas Raoux     unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
27254db8cc7SThomas Raoux     AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
27354db8cc7SThomas Raoux     llvm::MapVector<
27454db8cc7SThomas Raoux         SmallVector<int64_t>, Value,
27554db8cc7SThomas Raoux         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
27654db8cc7SThomas Raoux         accCache;
27754db8cc7SThomas Raoux 
27854db8cc7SThomas Raoux     SmallVector<int64_t> loopOrder = getUnrollOrder(
27954db8cc7SThomas Raoux         contractOp.getIteratorTypes().size(), contractOp, options);
280831041beSChristopher Bate 
281831041beSChristopher Bate     for (SmallVector<int64_t> offsets :
282831041beSChristopher Bate          StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
2837a69a9d7SNicolas Vasilache       SmallVector<Value> slicesOperands(contractOp.getNumOperands());
28454db8cc7SThomas Raoux 
2857a69a9d7SNicolas Vasilache       // Helper to compute the new shape of each operand and extract the slice.
28654db8cc7SThomas Raoux       auto extractOperand = [&](unsigned index, Value operand,
28754db8cc7SThomas Raoux                                 AffineMap permutationMap,
28854db8cc7SThomas Raoux                                 ArrayRef<int64_t> operandOffets) {
28954db8cc7SThomas Raoux         SmallVector<int64_t> operandShape = applyPermutationMap(
29054db8cc7SThomas Raoux             permutationMap, ArrayRef<int64_t>(*targetShape));
2917a69a9d7SNicolas Vasilache         SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
29254db8cc7SThomas Raoux         slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
29354db8cc7SThomas Raoux             loc, operand, operandOffets, operandShape, operandStrides);
29454db8cc7SThomas Raoux       };
29554db8cc7SThomas Raoux 
29654db8cc7SThomas Raoux       // Extract the new lhs operand.
29754db8cc7SThomas Raoux       AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
29854db8cc7SThomas Raoux       SmallVector<int64_t> lhsOffets =
29954db8cc7SThomas Raoux           applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
30054db8cc7SThomas Raoux       extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
30154db8cc7SThomas Raoux 
30254db8cc7SThomas Raoux       // Extract the new rhs operand.
30354db8cc7SThomas Raoux       AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
30454db8cc7SThomas Raoux       SmallVector<int64_t> rhsOffets =
30554db8cc7SThomas Raoux           applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
30654db8cc7SThomas Raoux       extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
30754db8cc7SThomas Raoux 
30854db8cc7SThomas Raoux       AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
30954db8cc7SThomas Raoux       SmallVector<int64_t> accOffets =
31054db8cc7SThomas Raoux           applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
31154db8cc7SThomas Raoux       // If a version of the accumulator has already been computed, use it
31254db8cc7SThomas Raoux       // otherwise extract the first version from the original operand.
31389dc313aSMehdi Amini       auto *accIt = accCache.find(accOffets);
31454db8cc7SThomas Raoux       if (accIt != accCache.end())
31554db8cc7SThomas Raoux         slicesOperands[2] = accIt->second;
31654db8cc7SThomas Raoux       else
31754db8cc7SThomas Raoux         extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
31854db8cc7SThomas Raoux 
31954db8cc7SThomas Raoux       SmallVector<int64_t> dstShape =
32054db8cc7SThomas Raoux           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
32154db8cc7SThomas Raoux       auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
32254db8cc7SThomas Raoux       Operation *newOp = cloneOpWithOperandsAndTypes(
32354db8cc7SThomas Raoux           rewriter, loc, contractOp, slicesOperands, targetType);
32454db8cc7SThomas Raoux 
32554db8cc7SThomas Raoux       SmallVector<int64_t> dstOffets =
32654db8cc7SThomas Raoux           applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
32754db8cc7SThomas Raoux       // Save the accumulated value untill all the loops are unrolled since
32854db8cc7SThomas Raoux       // reduction loop keep updating the accumulator.
32954db8cc7SThomas Raoux       accCache[dstOffets] = newOp->getResult(0);
33054db8cc7SThomas Raoux     }
33154db8cc7SThomas Raoux     // Assemble back the accumulator into a single vector.
33254db8cc7SThomas Raoux     Value result = rewriter.create<arith::ConstantOp>(
33354db8cc7SThomas Raoux         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
33454db8cc7SThomas Raoux     for (const auto &it : accCache) {
33554db8cc7SThomas Raoux       SmallVector<int64_t> dstStrides(it.first.size(), 1);
33654db8cc7SThomas Raoux       result = rewriter.create<vector::InsertStridedSliceOp>(
33754db8cc7SThomas Raoux           loc, it.second, result, it.first, dstStrides);
33854db8cc7SThomas Raoux     }
33954db8cc7SThomas Raoux     rewriter.replaceOp(contractOp, result);
34054db8cc7SThomas Raoux     return success();
34154db8cc7SThomas Raoux   }
34254db8cc7SThomas Raoux 
34354db8cc7SThomas Raoux private:
34454db8cc7SThomas Raoux   vector::UnrollVectorOptions options;
34554db8cc7SThomas Raoux };
34654db8cc7SThomas Raoux 
34754db8cc7SThomas Raoux struct UnrollMultiReductionPattern
34854db8cc7SThomas Raoux     : public OpRewritePattern<vector::MultiDimReductionOp> {
34954db8cc7SThomas Raoux   UnrollMultiReductionPattern(MLIRContext *context,
35054db8cc7SThomas Raoux                               const vector::UnrollVectorOptions &options,
35154db8cc7SThomas Raoux                               PatternBenefit benefit = 1)
35254db8cc7SThomas Raoux       : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
35354db8cc7SThomas Raoux         options(options) {}
35454db8cc7SThomas Raoux 
35554db8cc7SThomas Raoux   LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
35654db8cc7SThomas Raoux                                 PatternRewriter &rewriter) const override {
3570a81ace0SKazu Hirata     std::optional<SmallVector<int64_t>> targetShape =
35854db8cc7SThomas Raoux         getTargetShape(options, reductionOp);
35954db8cc7SThomas Raoux     if (!targetShape)
36054db8cc7SThomas Raoux       return failure();
3617a69a9d7SNicolas Vasilache     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
36254db8cc7SThomas Raoux     llvm::MapVector<
36354db8cc7SThomas Raoux         SmallVector<int64_t>, Value,
36454db8cc7SThomas Raoux         llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
36554db8cc7SThomas Raoux         accCache;
36654db8cc7SThomas Raoux     Location loc = reductionOp.getLoc();
3677a69a9d7SNicolas Vasilache 
3687a69a9d7SNicolas Vasilache     // Stride of the ratios, this gives us the offsets of sliceCount in a basis
3697a69a9d7SNicolas Vasilache     // of multiples of the targetShape.
370831041beSChristopher Bate     for (SmallVector<int64_t> offsets :
371831041beSChristopher Bate          StaticTileOffsetRange(originalSize, *targetShape)) {
37254db8cc7SThomas Raoux       SmallVector<Value> operands;
3737a69a9d7SNicolas Vasilache       SmallVector<int64_t> operandStrides(offsets.size(), 1);
37454db8cc7SThomas Raoux       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
37554db8cc7SThomas Raoux           loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
37654db8cc7SThomas Raoux       operands.push_back(slicedOperand);
37754db8cc7SThomas Raoux       SmallVector<int64_t> dstShape;
37854db8cc7SThomas Raoux       SmallVector<int64_t> destOffset;
37954db8cc7SThomas Raoux       for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
38054db8cc7SThomas Raoux         if (!reductionOp.isReducedDim(i)) {
38154db8cc7SThomas Raoux           destOffset.push_back(offsets[i]);
38254db8cc7SThomas Raoux           dstShape.push_back((*targetShape)[i]);
38354db8cc7SThomas Raoux         }
38454db8cc7SThomas Raoux       }
38554db8cc7SThomas Raoux       Value acc;
3867a69a9d7SNicolas Vasilache       SmallVector<int64_t> accStrides(destOffset.size(), 1);
38754db8cc7SThomas Raoux       // If a version of the accumulator has already been computed, use it
38854db8cc7SThomas Raoux       // otherwise extract the first version from the original operand.
38989dc313aSMehdi Amini       auto *accIt = accCache.find(destOffset);
39054db8cc7SThomas Raoux       if (accIt != accCache.end())
39154db8cc7SThomas Raoux         acc = accIt->second;
39254db8cc7SThomas Raoux       else
39354db8cc7SThomas Raoux         acc = rewriter.create<vector::ExtractStridedSliceOp>(
39454db8cc7SThomas Raoux             loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
39554db8cc7SThomas Raoux       operands.push_back(acc);
39654db8cc7SThomas Raoux       auto targetType = VectorType::get(
39754db8cc7SThomas Raoux           dstShape, reductionOp.getSourceVectorType().getElementType());
39854db8cc7SThomas Raoux       Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
39954db8cc7SThomas Raoux                                                      operands, targetType);
40054db8cc7SThomas Raoux       Value result = newOp->getResult(0);
40154db8cc7SThomas Raoux       accCache[destOffset] = result;
40254db8cc7SThomas Raoux     }
40354db8cc7SThomas Raoux     // Assemble back the accumulator into a single vector.
40454db8cc7SThomas Raoux     Value result = rewriter.create<arith::ConstantOp>(
40554db8cc7SThomas Raoux         loc, reductionOp.getDestType(),
40654db8cc7SThomas Raoux         rewriter.getZeroAttr(reductionOp.getDestType()));
40754db8cc7SThomas Raoux     for (const auto &it : accCache) {
40854db8cc7SThomas Raoux       SmallVector<int64_t> dstStrides(it.first.size(), 1);
40954db8cc7SThomas Raoux       result = rewriter.create<vector::InsertStridedSliceOp>(
41054db8cc7SThomas Raoux           loc, it.second, result, it.first, dstStrides);
41154db8cc7SThomas Raoux     }
41254db8cc7SThomas Raoux     rewriter.replaceOp(reductionOp, result);
41354db8cc7SThomas Raoux     return success();
41454db8cc7SThomas Raoux   }
41554db8cc7SThomas Raoux 
41654db8cc7SThomas Raoux private:
41754db8cc7SThomas Raoux   vector::UnrollVectorOptions options;
41854db8cc7SThomas Raoux };
41954db8cc7SThomas Raoux 
42054db8cc7SThomas Raoux struct UnrollElementwisePattern : public RewritePattern {
42154db8cc7SThomas Raoux   UnrollElementwisePattern(MLIRContext *context,
42254db8cc7SThomas Raoux                            const vector::UnrollVectorOptions &options,
42354db8cc7SThomas Raoux                            PatternBenefit benefit = 1)
42454db8cc7SThomas Raoux       : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
42554db8cc7SThomas Raoux         options(options) {}
42654db8cc7SThomas Raoux 
42754db8cc7SThomas Raoux   LogicalResult matchAndRewrite(Operation *op,
42854db8cc7SThomas Raoux                                 PatternRewriter &rewriter) const override {
42954db8cc7SThomas Raoux     if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
43054db8cc7SThomas Raoux       return failure();
43154db8cc7SThomas Raoux     auto targetShape = getTargetShape(options, op);
43254db8cc7SThomas Raoux     if (!targetShape)
43354db8cc7SThomas Raoux       return failure();
4345550c821STres Popp     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
4357a69a9d7SNicolas Vasilache     SmallVector<int64_t> originalSize =
43654db8cc7SThomas Raoux         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
43754db8cc7SThomas Raoux     Location loc = op->getLoc();
43854db8cc7SThomas Raoux     // Prepare the result vector.
43954db8cc7SThomas Raoux     Value result = rewriter.create<arith::ConstantOp>(
44054db8cc7SThomas Raoux         loc, dstVecType, rewriter.getZeroAttr(dstVecType));
4417a69a9d7SNicolas Vasilache     SmallVector<int64_t> strides(targetShape->size(), 1);
44254db8cc7SThomas Raoux     VectorType newVecType =
44354db8cc7SThomas Raoux         VectorType::get(*targetShape, dstVecType.getElementType());
4447a69a9d7SNicolas Vasilache 
445831041beSChristopher Bate     // Create the unrolled computation.
446831041beSChristopher Bate     for (SmallVector<int64_t> offsets :
447831041beSChristopher Bate          StaticTileOffsetRange(originalSize, *targetShape)) {
4487a69a9d7SNicolas Vasilache       SmallVector<Value> extractOperands;
44954db8cc7SThomas Raoux       for (OpOperand &operand : op->getOpOperands()) {
4505550c821STres Popp         auto vecType = dyn_cast<VectorType>(operand.get().getType());
45154db8cc7SThomas Raoux         if (!vecType) {
45254db8cc7SThomas Raoux           extractOperands.push_back(operand.get());
45354db8cc7SThomas Raoux           continue;
45454db8cc7SThomas Raoux         }
45554db8cc7SThomas Raoux         extractOperands.push_back(
45654db8cc7SThomas Raoux             rewriter.create<vector::ExtractStridedSliceOp>(
45754db8cc7SThomas Raoux                 loc, operand.get(), offsets, *targetShape, strides));
45854db8cc7SThomas Raoux       }
45954db8cc7SThomas Raoux       Operation *newOp = cloneOpWithOperandsAndTypes(
46054db8cc7SThomas Raoux           rewriter, loc, op, extractOperands, newVecType);
46154db8cc7SThomas Raoux       result = rewriter.create<vector::InsertStridedSliceOp>(
46254db8cc7SThomas Raoux           loc, newOp->getResult(0), result, offsets, strides);
46354db8cc7SThomas Raoux     }
46454db8cc7SThomas Raoux     rewriter.replaceOp(op, result);
46554db8cc7SThomas Raoux     return success();
46654db8cc7SThomas Raoux   }
46754db8cc7SThomas Raoux 
46854db8cc7SThomas Raoux private:
46954db8cc7SThomas Raoux   vector::UnrollVectorOptions options;
47054db8cc7SThomas Raoux };
47154db8cc7SThomas Raoux 
47254db8cc7SThomas Raoux struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
47354db8cc7SThomas Raoux   UnrollReductionPattern(MLIRContext *context,
47454db8cc7SThomas Raoux                          const vector::UnrollVectorOptions &options,
47554db8cc7SThomas Raoux                          PatternBenefit benefit = 1)
47654db8cc7SThomas Raoux       : OpRewritePattern<vector::ReductionOp>(context, benefit),
47754db8cc7SThomas Raoux         options(options) {}
47854db8cc7SThomas Raoux 
47954db8cc7SThomas Raoux   LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
48054db8cc7SThomas Raoux                                 PatternRewriter &rewriter) const override {
4810a81ace0SKazu Hirata     std::optional<SmallVector<int64_t>> targetShape =
48254db8cc7SThomas Raoux         getTargetShape(options, reductionOp);
48354db8cc7SThomas Raoux     if (!targetShape)
48454db8cc7SThomas Raoux       return failure();
48554db8cc7SThomas Raoux     SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
48654db8cc7SThomas Raoux 
48754db8cc7SThomas Raoux     // Create unrolled vector reduction.
48854db8cc7SThomas Raoux     Location loc = reductionOp.getLoc();
48954db8cc7SThomas Raoux     Value accumulator = nullptr;
490831041beSChristopher Bate     for (SmallVector<int64_t> offsets :
491831041beSChristopher Bate          StaticTileOffsetRange(originalSize, *targetShape)) {
49254db8cc7SThomas Raoux       SmallVector<int64_t> strides(offsets.size(), 1);
49354db8cc7SThomas Raoux       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
49454db8cc7SThomas Raoux           loc, reductionOp.getVector(), offsets, *targetShape, strides);
49554db8cc7SThomas Raoux       Operation *newOp = cloneOpWithOperandsAndTypes(
49654db8cc7SThomas Raoux           rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
49754db8cc7SThomas Raoux       Value result = newOp->getResult(0);
49854db8cc7SThomas Raoux 
49954db8cc7SThomas Raoux       if (!accumulator) {
50054db8cc7SThomas Raoux         // This is the first reduction.
50154db8cc7SThomas Raoux         accumulator = result;
50254db8cc7SThomas Raoux       } else {
50354db8cc7SThomas Raoux         // On subsequent reduction, combine with the accumulator.
50454db8cc7SThomas Raoux         accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
50554db8cc7SThomas Raoux                                          accumulator, result);
50654db8cc7SThomas Raoux       }
50754db8cc7SThomas Raoux     }
50854db8cc7SThomas Raoux 
50954db8cc7SThomas Raoux     rewriter.replaceOp(reductionOp, accumulator);
51054db8cc7SThomas Raoux     return success();
51154db8cc7SThomas Raoux   }
51254db8cc7SThomas Raoux 
51354db8cc7SThomas Raoux private:
51454db8cc7SThomas Raoux   const vector::UnrollVectorOptions options;
51554db8cc7SThomas Raoux };
51654db8cc7SThomas Raoux 
517490c77e4Sjacquesguan struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
518490c77e4Sjacquesguan   UnrollTransposePattern(MLIRContext *context,
51954db8cc7SThomas Raoux                          const vector::UnrollVectorOptions &options,
52054db8cc7SThomas Raoux                          PatternBenefit benefit = 1)
52154db8cc7SThomas Raoux       : OpRewritePattern<vector::TransposeOp>(context, benefit),
52254db8cc7SThomas Raoux         options(options) {}
52354db8cc7SThomas Raoux 
524490c77e4Sjacquesguan   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
52554db8cc7SThomas Raoux                                 PatternRewriter &rewriter) const override {
526a1aad28dSLei Zhang     if (transposeOp.getResultVectorType().getRank() == 0)
52754db8cc7SThomas Raoux       return failure();
528490c77e4Sjacquesguan     auto targetShape = getTargetShape(options, transposeOp);
52954db8cc7SThomas Raoux     if (!targetShape)
53054db8cc7SThomas Raoux       return failure();
531a1aad28dSLei Zhang     auto originalVectorType = transposeOp.getResultVectorType();
5327a69a9d7SNicolas Vasilache     SmallVector<int64_t> strides(targetShape->size(), 1);
533490c77e4Sjacquesguan     Location loc = transposeOp.getLoc();
53454db8cc7SThomas Raoux     ArrayRef<int64_t> originalSize = originalVectorType.getShape();
535831041beSChristopher Bate 
53654db8cc7SThomas Raoux     // Prepare the result vector;
53754db8cc7SThomas Raoux     Value result = rewriter.create<arith::ConstantOp>(
53854db8cc7SThomas Raoux         loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
53932c3decbSMatthias Springer     ArrayRef<int64_t> permutation = transposeOp.getPermutation();
5407a69a9d7SNicolas Vasilache 
541831041beSChristopher Bate     // Unroll the computation.
542831041beSChristopher Bate     for (SmallVector<int64_t> elementOffsets :
543831041beSChristopher Bate          StaticTileOffsetRange(originalSize, *targetShape)) {
5447a69a9d7SNicolas Vasilache       SmallVector<int64_t> permutedOffsets(elementOffsets.size());
5457a69a9d7SNicolas Vasilache       SmallVector<int64_t> permutedShape(elementOffsets.size());
54654db8cc7SThomas Raoux       // Compute the source offsets and shape.
5478c258fdaSJakub Kuderski       for (auto indices : llvm::enumerate(permutation)) {
54854db8cc7SThomas Raoux         permutedOffsets[indices.value()] = elementOffsets[indices.index()];
54954db8cc7SThomas Raoux         permutedShape[indices.value()] = (*targetShape)[indices.index()];
55054db8cc7SThomas Raoux       }
55154db8cc7SThomas Raoux       Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
552490c77e4Sjacquesguan           loc, transposeOp.getVector(), permutedOffsets, permutedShape,
553490c77e4Sjacquesguan           strides);
554490c77e4Sjacquesguan       Value transposedSlice =
55554db8cc7SThomas Raoux           rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
55654db8cc7SThomas Raoux       result = rewriter.create<vector::InsertStridedSliceOp>(
557490c77e4Sjacquesguan           loc, transposedSlice, result, elementOffsets, strides);
55854db8cc7SThomas Raoux     }
559490c77e4Sjacquesguan     rewriter.replaceOp(transposeOp, result);
56054db8cc7SThomas Raoux     return success();
56154db8cc7SThomas Raoux   }
56254db8cc7SThomas Raoux 
56354db8cc7SThomas Raoux private:
56454db8cc7SThomas Raoux   vector::UnrollVectorOptions options;
56554db8cc7SThomas Raoux };
56654db8cc7SThomas Raoux 
567435f7d4cSQuinn Dawkins struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
568435f7d4cSQuinn Dawkins   UnrollGatherPattern(MLIRContext *context,
569435f7d4cSQuinn Dawkins                       const vector::UnrollVectorOptions &options,
570435f7d4cSQuinn Dawkins                       PatternBenefit benefit = 1)
571435f7d4cSQuinn Dawkins       : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
572435f7d4cSQuinn Dawkins   }
573435f7d4cSQuinn Dawkins 
574435f7d4cSQuinn Dawkins   LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
575435f7d4cSQuinn Dawkins                                 PatternRewriter &rewriter) const override {
576435f7d4cSQuinn Dawkins     VectorType sourceVectorType = gatherOp.getVectorType();
577435f7d4cSQuinn Dawkins     if (sourceVectorType.getRank() == 0)
578435f7d4cSQuinn Dawkins       return failure();
579435f7d4cSQuinn Dawkins     auto targetShape = getTargetShape(options, gatherOp);
580435f7d4cSQuinn Dawkins     if (!targetShape)
581435f7d4cSQuinn Dawkins       return failure();
582435f7d4cSQuinn Dawkins     SmallVector<int64_t> strides(targetShape->size(), 1);
583435f7d4cSQuinn Dawkins     Location loc = gatherOp.getLoc();
584435f7d4cSQuinn Dawkins     ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
585435f7d4cSQuinn Dawkins 
586435f7d4cSQuinn Dawkins     // Prepare the result vector;
587435f7d4cSQuinn Dawkins     Value result = rewriter.create<arith::ConstantOp>(
588435f7d4cSQuinn Dawkins         loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
589435f7d4cSQuinn Dawkins     auto targetType =
590435f7d4cSQuinn Dawkins         VectorType::get(*targetShape, sourceVectorType.getElementType());
591435f7d4cSQuinn Dawkins 
592435f7d4cSQuinn Dawkins     SmallVector<int64_t> loopOrder =
593435f7d4cSQuinn Dawkins         getUnrollOrder(originalSize.size(), gatherOp, options);
594831041beSChristopher Bate     for (SmallVector<int64_t> elementOffsets :
595831041beSChristopher Bate          StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
596435f7d4cSQuinn Dawkins       // To get the unrolled gather, extract the same slice based on the
597435f7d4cSQuinn Dawkins       // decomposed shape from each of the index, mask, and pass-through
598435f7d4cSQuinn Dawkins       // vectors.
599435f7d4cSQuinn Dawkins       Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
600435f7d4cSQuinn Dawkins           loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
601435f7d4cSQuinn Dawkins       Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
602435f7d4cSQuinn Dawkins           loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
603435f7d4cSQuinn Dawkins       Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
604435f7d4cSQuinn Dawkins           loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
605435f7d4cSQuinn Dawkins       auto slicedGather = rewriter.create<vector::GatherOp>(
606435f7d4cSQuinn Dawkins           loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
607435f7d4cSQuinn Dawkins           indexSubVec, maskSubVec, passThruSubVec);
608435f7d4cSQuinn Dawkins 
609435f7d4cSQuinn Dawkins       result = rewriter.create<vector::InsertStridedSliceOp>(
610435f7d4cSQuinn Dawkins           loc, slicedGather, result, elementOffsets, strides);
611435f7d4cSQuinn Dawkins     }
612435f7d4cSQuinn Dawkins     rewriter.replaceOp(gatherOp, result);
613435f7d4cSQuinn Dawkins     return success();
614435f7d4cSQuinn Dawkins   }
615435f7d4cSQuinn Dawkins 
616435f7d4cSQuinn Dawkins private:
617435f7d4cSQuinn Dawkins   vector::UnrollVectorOptions options;
618435f7d4cSQuinn Dawkins };
619435f7d4cSQuinn Dawkins 
62054db8cc7SThomas Raoux } // namespace
62154db8cc7SThomas Raoux 
62254db8cc7SThomas Raoux void mlir::vector::populateVectorUnrollPatterns(
62354db8cc7SThomas Raoux     RewritePatternSet &patterns, const UnrollVectorOptions &options,
62454db8cc7SThomas Raoux     PatternBenefit benefit) {
62554db8cc7SThomas Raoux   patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
62654db8cc7SThomas Raoux                UnrollContractionPattern, UnrollElementwisePattern,
62754db8cc7SThomas Raoux                UnrollReductionPattern, UnrollMultiReductionPattern,
628435f7d4cSQuinn Dawkins                UnrollTransposePattern, UnrollGatherPattern>(
629435f7d4cSQuinn Dawkins       patterns.getContext(), options, benefit);
63054db8cc7SThomas Raoux }
631