xref: /llvm-project/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (revision a7a4c16c672bdd8e245af533a1f170522e26e42a)
12bc4c3e9SNicolas Vasilache //===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
22bc4c3e9SNicolas Vasilache //
32bc4c3e9SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42bc4c3e9SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information.
52bc4c3e9SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62bc4c3e9SNicolas Vasilache //
72bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===//
82bc4c3e9SNicolas Vasilache //
92bc4c3e9SNicolas Vasilache // This file implements target-independent rewrites and utilities to lower the
102bc4c3e9SNicolas Vasilache // 'vector.shape_cast' operation.
112bc4c3e9SNicolas Vasilache //
122bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===//
132bc4c3e9SNicolas Vasilache 
142bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Arith/IR/Arith.h"
152bc4c3e9SNicolas Vasilache #include "mlir/Dialect/MemRef/IR/MemRef.h"
162bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h"
172bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
182bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
192bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
202bc4c3e9SNicolas Vasilache #include "mlir/IR/BuiltinTypes.h"
212bc4c3e9SNicolas Vasilache #include "mlir/IR/Location.h"
222bc4c3e9SNicolas Vasilache #include "mlir/IR/PatternMatch.h"
232bc4c3e9SNicolas Vasilache #include "mlir/IR/TypeUtilities.h"
242bc4c3e9SNicolas Vasilache 
252bc4c3e9SNicolas Vasilache #define DEBUG_TYPE "vector-shape-cast-lowering"
262bc4c3e9SNicolas Vasilache 
272bc4c3e9SNicolas Vasilache using namespace mlir;
282bc4c3e9SNicolas Vasilache using namespace mlir::vector;
292bc4c3e9SNicolas Vasilache 
30*a7a4c16cSDiego Caballero /// Increments n-D `indices` by `step` starting from the innermost dimension.
31*a7a4c16cSDiego Caballero static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
32*a7a4c16cSDiego Caballero                    int step = 1) {
33*a7a4c16cSDiego Caballero   for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
34*a7a4c16cSDiego Caballero     assert(indices[dim] < vecType.getDimSize(dim) &&
35*a7a4c16cSDiego Caballero            "Indices are out of bound");
36*a7a4c16cSDiego Caballero     indices[dim] += step;
37*a7a4c16cSDiego Caballero     if (indices[dim] < vecType.getDimSize(dim))
388dffb71cSBenjamin Maxwell       break;
39*a7a4c16cSDiego Caballero 
40*a7a4c16cSDiego Caballero     indices[dim] = 0;
41*a7a4c16cSDiego Caballero     step = 1;
428dffb71cSBenjamin Maxwell   }
438dffb71cSBenjamin Maxwell }
44*a7a4c16cSDiego Caballero 
45*a7a4c16cSDiego Caballero namespace {
46*a7a4c16cSDiego Caballero /// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
47*a7a4c16cSDiego Caballero /// vectors progressively. This iterates over the n-1 major dimensions of the
48*a7a4c16cSDiego Caballero /// n-D vector and performs rewrites into:
49*a7a4c16cSDiego Caballero ///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
50*a7a4c16cSDiego Caballero class ShapeCastOpNDDownCastRewritePattern
51*a7a4c16cSDiego Caballero     : public OpRewritePattern<vector::ShapeCastOp> {
52*a7a4c16cSDiego Caballero public:
53*a7a4c16cSDiego Caballero   using OpRewritePattern::OpRewritePattern;
54*a7a4c16cSDiego Caballero 
55*a7a4c16cSDiego Caballero   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
56*a7a4c16cSDiego Caballero                                 PatternRewriter &rewriter) const override {
57*a7a4c16cSDiego Caballero     auto sourceVectorType = op.getSourceVectorType();
58*a7a4c16cSDiego Caballero     auto resultVectorType = op.getResultVectorType();
59*a7a4c16cSDiego Caballero     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
60*a7a4c16cSDiego Caballero       return failure();
61*a7a4c16cSDiego Caballero 
62*a7a4c16cSDiego Caballero     int64_t srcRank = sourceVectorType.getRank();
63*a7a4c16cSDiego Caballero     int64_t resRank = resultVectorType.getRank();
64*a7a4c16cSDiego Caballero     if (srcRank < 2 || resRank != 1)
65*a7a4c16cSDiego Caballero       return failure();
66*a7a4c16cSDiego Caballero 
67*a7a4c16cSDiego Caballero     // Compute the number of 1-D vector elements involved in the reshape.
68*a7a4c16cSDiego Caballero     int64_t numElts = 1;
69*a7a4c16cSDiego Caballero     for (int64_t dim = 0; dim < srcRank - 1; ++dim)
70*a7a4c16cSDiego Caballero       numElts *= sourceVectorType.getDimSize(dim);
71*a7a4c16cSDiego Caballero 
72*a7a4c16cSDiego Caballero     auto loc = op.getLoc();
73*a7a4c16cSDiego Caballero     SmallVector<int64_t> srcIdx(srcRank - 1, 0);
74*a7a4c16cSDiego Caballero     SmallVector<int64_t> resIdx(resRank, 0);
75*a7a4c16cSDiego Caballero     int64_t extractSize = sourceVectorType.getShape().back();
76*a7a4c16cSDiego Caballero     Value result = rewriter.create<arith::ConstantOp>(
77*a7a4c16cSDiego Caballero         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
78*a7a4c16cSDiego Caballero 
79*a7a4c16cSDiego Caballero     // Compute the indices of each 1-D vector element of the source extraction
80*a7a4c16cSDiego Caballero     // and destination slice insertion and generate such instructions.
81*a7a4c16cSDiego Caballero     for (int64_t i = 0; i < numElts; ++i) {
82*a7a4c16cSDiego Caballero       if (i != 0) {
83*a7a4c16cSDiego Caballero         incIdx(srcIdx, sourceVectorType, /*step=*/1);
84*a7a4c16cSDiego Caballero         incIdx(resIdx, resultVectorType, /*step=*/extractSize);
858dffb71cSBenjamin Maxwell       }
868dffb71cSBenjamin Maxwell 
87*a7a4c16cSDiego Caballero       Value extract =
88*a7a4c16cSDiego Caballero           rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
89*a7a4c16cSDiego Caballero       result = rewriter.create<vector::InsertStridedSliceOp>(
90*a7a4c16cSDiego Caballero           loc, extract, result,
91*a7a4c16cSDiego Caballero           /*offsets=*/resIdx, /*strides=*/1);
92*a7a4c16cSDiego Caballero     }
93*a7a4c16cSDiego Caballero 
94*a7a4c16cSDiego Caballero     rewriter.replaceOp(op, result);
95*a7a4c16cSDiego Caballero     return success();
96*a7a4c16cSDiego Caballero   }
97*a7a4c16cSDiego Caballero };
98*a7a4c16cSDiego Caballero 
99*a7a4c16cSDiego Caballero /// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
100*a7a4c16cSDiego Caballero /// vectors progressively. This iterates over the n-1 major dimension of the n-D
101*a7a4c16cSDiego Caballero /// vector and performs rewrites into:
102*a7a4c16cSDiego Caballero ///   vector.extract_strided_slice from 1-D + vector.insert into n-D
103*a7a4c16cSDiego Caballero /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
104*a7a4c16cSDiego Caballero class ShapeCastOpNDUpCastRewritePattern
105*a7a4c16cSDiego Caballero     : public OpRewritePattern<vector::ShapeCastOp> {
106*a7a4c16cSDiego Caballero public:
107*a7a4c16cSDiego Caballero   using OpRewritePattern::OpRewritePattern;
108*a7a4c16cSDiego Caballero 
109*a7a4c16cSDiego Caballero   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
110*a7a4c16cSDiego Caballero                                 PatternRewriter &rewriter) const override {
111*a7a4c16cSDiego Caballero     auto sourceVectorType = op.getSourceVectorType();
112*a7a4c16cSDiego Caballero     auto resultVectorType = op.getResultVectorType();
113*a7a4c16cSDiego Caballero     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
114*a7a4c16cSDiego Caballero       return failure();
115*a7a4c16cSDiego Caballero 
116*a7a4c16cSDiego Caballero     int64_t srcRank = sourceVectorType.getRank();
117*a7a4c16cSDiego Caballero     int64_t resRank = resultVectorType.getRank();
118*a7a4c16cSDiego Caballero     if (srcRank != 1 || resRank < 2)
119*a7a4c16cSDiego Caballero       return failure();
120*a7a4c16cSDiego Caballero 
121*a7a4c16cSDiego Caballero     // Compute the number of 1-D vector elements involved in the reshape.
122*a7a4c16cSDiego Caballero     int64_t numElts = 1;
123*a7a4c16cSDiego Caballero     for (int64_t dim = 0; dim < resRank - 1; ++dim)
124*a7a4c16cSDiego Caballero       numElts *= resultVectorType.getDimSize(dim);
125*a7a4c16cSDiego Caballero 
126*a7a4c16cSDiego Caballero     // Compute the indices of each 1-D vector element of the source slice
127*a7a4c16cSDiego Caballero     // extraction and destination insertion and generate such instructions.
128*a7a4c16cSDiego Caballero     auto loc = op.getLoc();
129*a7a4c16cSDiego Caballero     SmallVector<int64_t> srcIdx(srcRank, 0);
130*a7a4c16cSDiego Caballero     SmallVector<int64_t> resIdx(resRank - 1, 0);
131*a7a4c16cSDiego Caballero     int64_t extractSize = resultVectorType.getShape().back();
132*a7a4c16cSDiego Caballero     Value result = rewriter.create<arith::ConstantOp>(
133*a7a4c16cSDiego Caballero         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
134*a7a4c16cSDiego Caballero     for (int64_t i = 0; i < numElts; ++i) {
135*a7a4c16cSDiego Caballero       if (i != 0) {
136*a7a4c16cSDiego Caballero         incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
137*a7a4c16cSDiego Caballero         incIdx(resIdx, resultVectorType, /*step=*/1);
138*a7a4c16cSDiego Caballero       }
139*a7a4c16cSDiego Caballero 
140*a7a4c16cSDiego Caballero       Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
141*a7a4c16cSDiego Caballero           loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
142*a7a4c16cSDiego Caballero           /*strides=*/1);
143*a7a4c16cSDiego Caballero       result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
144*a7a4c16cSDiego Caballero     }
145*a7a4c16cSDiego Caballero     rewriter.replaceOp(op, result);
146*a7a4c16cSDiego Caballero     return success();
147*a7a4c16cSDiego Caballero   }
148*a7a4c16cSDiego Caballero };
149*a7a4c16cSDiego Caballero 
1502bc4c3e9SNicolas Vasilache // We typically should not lower general shape cast operations into data
1512bc4c3e9SNicolas Vasilache // movement instructions, since the assumption is that these casts are
1522bc4c3e9SNicolas Vasilache // optimized away during progressive lowering. For completeness, however,
1532bc4c3e9SNicolas Vasilache // we fall back to a reference implementation that moves all elements
1542bc4c3e9SNicolas Vasilache // into the right place if we get here.
1552bc4c3e9SNicolas Vasilache class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
1562bc4c3e9SNicolas Vasilache public:
1572bc4c3e9SNicolas Vasilache   using OpRewritePattern::OpRewritePattern;
1582bc4c3e9SNicolas Vasilache 
1592bc4c3e9SNicolas Vasilache   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
1602bc4c3e9SNicolas Vasilache                                 PatternRewriter &rewriter) const override {
1612bc4c3e9SNicolas Vasilache     Location loc = op.getLoc();
1622bc4c3e9SNicolas Vasilache     auto sourceVectorType = op.getSourceVectorType();
1632bc4c3e9SNicolas Vasilache     auto resultVectorType = op.getResultVectorType();
1642bc4c3e9SNicolas Vasilache 
1658dffb71cSBenjamin Maxwell     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
1668dffb71cSBenjamin Maxwell       return failure();
1678dffb71cSBenjamin Maxwell 
168*a7a4c16cSDiego Caballero     // Special case for n-D / 1-D lowerings with better implementations.
1692bc4c3e9SNicolas Vasilache     int64_t srcRank = sourceVectorType.getRank();
1702bc4c3e9SNicolas Vasilache     int64_t resRank = resultVectorType.getRank();
171*a7a4c16cSDiego Caballero     if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
1722bc4c3e9SNicolas Vasilache       return failure();
1732bc4c3e9SNicolas Vasilache 
1742bc4c3e9SNicolas Vasilache     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
1752bc4c3e9SNicolas Vasilache     // extract/insert chains.
1762bc4c3e9SNicolas Vasilache     int64_t numElts = 1;
1772bc4c3e9SNicolas Vasilache     for (int64_t r = 0; r < srcRank; r++)
1782bc4c3e9SNicolas Vasilache       numElts *= sourceVectorType.getDimSize(r);
1792bc4c3e9SNicolas Vasilache     // Replace with data movement operations:
1802bc4c3e9SNicolas Vasilache     //    x[0,0,0] = y[0,0]
1812bc4c3e9SNicolas Vasilache     //    x[0,0,1] = y[0,1]
1822bc4c3e9SNicolas Vasilache     //    x[0,1,0] = y[0,2]
1832bc4c3e9SNicolas Vasilache     // etc., incrementing the two index vectors "row-major"
1842bc4c3e9SNicolas Vasilache     // within the source and result shape.
185*a7a4c16cSDiego Caballero     SmallVector<int64_t> srcIdx(srcRank, 0);
186*a7a4c16cSDiego Caballero     SmallVector<int64_t> resIdx(resRank, 0);
1872bc4c3e9SNicolas Vasilache     Value result = rewriter.create<arith::ConstantOp>(
1882bc4c3e9SNicolas Vasilache         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
1892bc4c3e9SNicolas Vasilache     for (int64_t i = 0; i < numElts; i++) {
1902bc4c3e9SNicolas Vasilache       if (i != 0) {
191*a7a4c16cSDiego Caballero         incIdx(srcIdx, sourceVectorType);
192*a7a4c16cSDiego Caballero         incIdx(resIdx, resultVectorType);
1932bc4c3e9SNicolas Vasilache       }
1940935c055SDiego Caballero 
1950935c055SDiego Caballero       Value extract;
1960935c055SDiego Caballero       if (srcRank == 0) {
1970935c055SDiego Caballero         // 0-D vector special case
1980935c055SDiego Caballero         assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
1990935c055SDiego Caballero         extract = rewriter.create<vector::ExtractElementOp>(
2000935c055SDiego Caballero             loc, op.getSourceVectorType().getElementType(), op.getSource());
2010935c055SDiego Caballero       } else {
2020935c055SDiego Caballero         extract =
2030935c055SDiego Caballero             rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
2040935c055SDiego Caballero       }
2050935c055SDiego Caballero 
2060935c055SDiego Caballero       if (resRank == 0) {
2070935c055SDiego Caballero         // 0-D vector special case
2080935c055SDiego Caballero         assert(resIdx.empty() && "Unexpected indices for 0-D vector");
2090935c055SDiego Caballero         result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
2100935c055SDiego Caballero       } else {
2110935c055SDiego Caballero         result =
2120935c055SDiego Caballero             rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
2130935c055SDiego Caballero       }
2142bc4c3e9SNicolas Vasilache     }
2152bc4c3e9SNicolas Vasilache     rewriter.replaceOp(op, result);
2162bc4c3e9SNicolas Vasilache     return success();
2172bc4c3e9SNicolas Vasilache   }
2188dffb71cSBenjamin Maxwell };
2192bc4c3e9SNicolas Vasilache 
2208dffb71cSBenjamin Maxwell /// A shape_cast lowering for scalable vectors with a single trailing scalable
2218dffb71cSBenjamin Maxwell /// dimension. This is similar to the general shape_cast lowering but makes use
2228dffb71cSBenjamin Maxwell /// of vector.scalable.insert and vector.scalable.extract to move elements a
2238dffb71cSBenjamin Maxwell /// subvector at a time.
2248dffb71cSBenjamin Maxwell ///
2258dffb71cSBenjamin Maxwell /// E.g.:
2268dffb71cSBenjamin Maxwell /// ```
2278dffb71cSBenjamin Maxwell /// // Flatten scalable vector
2288dffb71cSBenjamin Maxwell /// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
2298dffb71cSBenjamin Maxwell /// ```
2308dffb71cSBenjamin Maxwell /// is rewritten to:
2318dffb71cSBenjamin Maxwell /// ```
2328dffb71cSBenjamin Maxwell /// // Flatten scalable vector
2338dffb71cSBenjamin Maxwell /// %c = arith.constant dense<0> : vector<[8]xi32>
2349816edc9SCullen Rhodes /// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
2358dffb71cSBenjamin Maxwell /// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
2369816edc9SCullen Rhodes /// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
2378dffb71cSBenjamin Maxwell /// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
2388dffb71cSBenjamin Maxwell /// ```
2398dffb71cSBenjamin Maxwell /// or:
2408dffb71cSBenjamin Maxwell /// ```
2418dffb71cSBenjamin Maxwell /// // Un-flatten scalable vector
2428dffb71cSBenjamin Maxwell /// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
2438dffb71cSBenjamin Maxwell /// ```
2448dffb71cSBenjamin Maxwell /// is rewritten to:
2458dffb71cSBenjamin Maxwell /// ```
2468dffb71cSBenjamin Maxwell /// // Un-flatten scalable vector
2478dffb71cSBenjamin Maxwell /// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
2488dffb71cSBenjamin Maxwell /// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
2498dffb71cSBenjamin Maxwell /// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
2508dffb71cSBenjamin Maxwell /// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
2518dffb71cSBenjamin Maxwell /// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
2528dffb71cSBenjamin Maxwell /// ```
2538dffb71cSBenjamin Maxwell class ScalableShapeCastOpRewritePattern
2548dffb71cSBenjamin Maxwell     : public OpRewritePattern<vector::ShapeCastOp> {
2558dffb71cSBenjamin Maxwell public:
2568dffb71cSBenjamin Maxwell   using OpRewritePattern::OpRewritePattern;
2578dffb71cSBenjamin Maxwell 
2588dffb71cSBenjamin Maxwell   LogicalResult matchAndRewrite(vector::ShapeCastOp op,
2598dffb71cSBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
2608dffb71cSBenjamin Maxwell 
2618dffb71cSBenjamin Maxwell     Location loc = op.getLoc();
2628dffb71cSBenjamin Maxwell     auto sourceVectorType = op.getSourceVectorType();
2638dffb71cSBenjamin Maxwell     auto resultVectorType = op.getResultVectorType();
2648dffb71cSBenjamin Maxwell     auto srcRank = sourceVectorType.getRank();
2658dffb71cSBenjamin Maxwell     auto resRank = resultVectorType.getRank();
2668dffb71cSBenjamin Maxwell 
2678dffb71cSBenjamin Maxwell     // This can only lower shape_casts where both the source and result types
2688dffb71cSBenjamin Maxwell     // have a single trailing scalable dimension. This is because there are no
2698dffb71cSBenjamin Maxwell     // legal representation of other scalable types in LLVM (and likely won't be
2708dffb71cSBenjamin Maxwell     // soon). There are also (currently) no operations that can index or extract
271*a7a4c16cSDiego Caballero     // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
2728dffb71cSBenjamin Maxwell     if (!isTrailingDimScalable(sourceVectorType) ||
2738dffb71cSBenjamin Maxwell         !isTrailingDimScalable(resultVectorType)) {
2748dffb71cSBenjamin Maxwell       return failure();
2752bc4c3e9SNicolas Vasilache     }
2768dffb71cSBenjamin Maxwell 
2778dffb71cSBenjamin Maxwell     // The sizes of the trailing dimension of the source and result vectors, the
2788dffb71cSBenjamin Maxwell     // size of subvector to move, and the number of elements in the vectors.
2798dffb71cSBenjamin Maxwell     // These are "min" sizes as they are the size when vscale == 1.
2808dffb71cSBenjamin Maxwell     auto minSourceTrailingSize = sourceVectorType.getShape().back();
2818dffb71cSBenjamin Maxwell     auto minResultTrailingSize = resultVectorType.getShape().back();
2828dffb71cSBenjamin Maxwell     auto minExtractionSize =
2838dffb71cSBenjamin Maxwell         std::min(minSourceTrailingSize, minResultTrailingSize);
2848dffb71cSBenjamin Maxwell     int64_t minNumElts = 1;
2858dffb71cSBenjamin Maxwell     for (auto size : sourceVectorType.getShape())
2868dffb71cSBenjamin Maxwell       minNumElts *= size;
2878dffb71cSBenjamin Maxwell 
2888dffb71cSBenjamin Maxwell     // The subvector type to move from the source to the result. Note that this
2898dffb71cSBenjamin Maxwell     // is a scalable vector. This rewrite will generate code in terms of the
2908dffb71cSBenjamin Maxwell     // "min" size (vscale == 1 case), that scales to any vscale.
2918dffb71cSBenjamin Maxwell     auto extractionVectorType = VectorType::get(
2928dffb71cSBenjamin Maxwell         {minExtractionSize}, sourceVectorType.getElementType(), {true});
2938dffb71cSBenjamin Maxwell 
2948dffb71cSBenjamin Maxwell     Value result = rewriter.create<arith::ConstantOp>(
2958dffb71cSBenjamin Maxwell         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
2968dffb71cSBenjamin Maxwell 
297*a7a4c16cSDiego Caballero     SmallVector<int64_t> srcIdx(srcRank, 0);
298*a7a4c16cSDiego Caballero     SmallVector<int64_t> resIdx(resRank, 0);
2998dffb71cSBenjamin Maxwell 
3008dffb71cSBenjamin Maxwell     // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
3018dffb71cSBenjamin Maxwell     // once D150000 lands.
3028dffb71cSBenjamin Maxwell     Value currentResultScalableVector;
3038dffb71cSBenjamin Maxwell     Value currentSourceScalableVector;
3048dffb71cSBenjamin Maxwell     for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
3058dffb71cSBenjamin Maxwell       // 1. Extract a scalable subvector from the source vector.
3068dffb71cSBenjamin Maxwell       if (!currentSourceScalableVector) {
3078dffb71cSBenjamin Maxwell         if (srcRank != 1) {
3088dffb71cSBenjamin Maxwell           currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
3098dffb71cSBenjamin Maxwell               loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
3108dffb71cSBenjamin Maxwell         } else {
3118dffb71cSBenjamin Maxwell           currentSourceScalableVector = op.getSource();
3128dffb71cSBenjamin Maxwell         }
3138dffb71cSBenjamin Maxwell       }
3148dffb71cSBenjamin Maxwell       Value sourceSubVector = currentSourceScalableVector;
3158dffb71cSBenjamin Maxwell       if (minExtractionSize < minSourceTrailingSize) {
3168dffb71cSBenjamin Maxwell         sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
3178dffb71cSBenjamin Maxwell             loc, extractionVectorType, sourceSubVector, srcIdx.back());
3188dffb71cSBenjamin Maxwell       }
3198dffb71cSBenjamin Maxwell 
3208dffb71cSBenjamin Maxwell       // 2. Insert the scalable subvector into the result vector.
3218dffb71cSBenjamin Maxwell       if (!currentResultScalableVector) {
3228dffb71cSBenjamin Maxwell         if (minExtractionSize == minResultTrailingSize) {
3238dffb71cSBenjamin Maxwell           currentResultScalableVector = sourceSubVector;
3248dffb71cSBenjamin Maxwell         } else if (resRank != 1) {
3258dffb71cSBenjamin Maxwell           currentResultScalableVector = rewriter.create<vector::ExtractOp>(
3268dffb71cSBenjamin Maxwell               loc, result, llvm::ArrayRef(resIdx).drop_back());
3278dffb71cSBenjamin Maxwell         } else {
3288dffb71cSBenjamin Maxwell           currentResultScalableVector = result;
3298dffb71cSBenjamin Maxwell         }
3308dffb71cSBenjamin Maxwell       }
3318dffb71cSBenjamin Maxwell       if (minExtractionSize < minResultTrailingSize) {
3328dffb71cSBenjamin Maxwell         currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
3338dffb71cSBenjamin Maxwell             loc, sourceSubVector, currentResultScalableVector, resIdx.back());
3348dffb71cSBenjamin Maxwell       }
3358dffb71cSBenjamin Maxwell 
3368dffb71cSBenjamin Maxwell       // 3. Update the source and result scalable vectors if needed.
3378dffb71cSBenjamin Maxwell       if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
3388dffb71cSBenjamin Maxwell           currentResultScalableVector != result) {
3398dffb71cSBenjamin Maxwell         // Finished row of result. Insert complete scalable vector into result
3408dffb71cSBenjamin Maxwell         // (n-D) vector.
3418dffb71cSBenjamin Maxwell         result = rewriter.create<vector::InsertOp>(
3428dffb71cSBenjamin Maxwell             loc, currentResultScalableVector, result,
3438dffb71cSBenjamin Maxwell             llvm::ArrayRef(resIdx).drop_back());
3448dffb71cSBenjamin Maxwell         currentResultScalableVector = {};
3458dffb71cSBenjamin Maxwell       }
3468dffb71cSBenjamin Maxwell       if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
3478dffb71cSBenjamin Maxwell         // Finished row of source.
3488dffb71cSBenjamin Maxwell         currentSourceScalableVector = {};
3498dffb71cSBenjamin Maxwell       }
3508dffb71cSBenjamin Maxwell 
3518dffb71cSBenjamin Maxwell       // 4. Increment the insert/extract indices, stepping by minExtractionSize
3528dffb71cSBenjamin Maxwell       // for the trailing dimensions.
353*a7a4c16cSDiego Caballero       incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
354*a7a4c16cSDiego Caballero       incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
3558dffb71cSBenjamin Maxwell     }
3568dffb71cSBenjamin Maxwell 
3578dffb71cSBenjamin Maxwell     rewriter.replaceOp(op, result);
3588dffb71cSBenjamin Maxwell     return success();
3598dffb71cSBenjamin Maxwell   }
3608dffb71cSBenjamin Maxwell 
3618dffb71cSBenjamin Maxwell   static bool isTrailingDimScalable(VectorType type) {
3628dffb71cSBenjamin Maxwell     return type.getRank() >= 1 && type.getScalableDims().back() &&
3638dffb71cSBenjamin Maxwell            !llvm::is_contained(type.getScalableDims().drop_back(), true);
3642bc4c3e9SNicolas Vasilache   }
3652bc4c3e9SNicolas Vasilache };
3668dffb71cSBenjamin Maxwell 
3672bc4c3e9SNicolas Vasilache } // namespace
3682bc4c3e9SNicolas Vasilache 
3692bc4c3e9SNicolas Vasilache void mlir::vector::populateVectorShapeCastLoweringPatterns(
3702bc4c3e9SNicolas Vasilache     RewritePatternSet &patterns, PatternBenefit benefit) {
371*a7a4c16cSDiego Caballero   patterns.add<ShapeCastOpNDDownCastRewritePattern,
372*a7a4c16cSDiego Caballero                ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
3738dffb71cSBenjamin Maxwell                ScalableShapeCastOpRewritePattern>(patterns.getContext(),
3748dffb71cSBenjamin Maxwell                                                   benefit);
3752bc4c3e9SNicolas Vasilache }
376