12bc4c3e9SNicolas Vasilache //===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===// 22bc4c3e9SNicolas Vasilache // 32bc4c3e9SNicolas Vasilache // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 42bc4c3e9SNicolas Vasilache // See https://llvm.org/LICENSE.txt for license information. 52bc4c3e9SNicolas Vasilache // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 62bc4c3e9SNicolas Vasilache // 72bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 82bc4c3e9SNicolas Vasilache // 92bc4c3e9SNicolas Vasilache // This file implements target-independent rewrites and utilities to lower the 102bc4c3e9SNicolas Vasilache // 'vector.shape_cast' operation. 112bc4c3e9SNicolas Vasilache // 122bc4c3e9SNicolas Vasilache //===----------------------------------------------------------------------===// 132bc4c3e9SNicolas Vasilache 142bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Arith/IR/Arith.h" 152bc4c3e9SNicolas Vasilache #include "mlir/Dialect/MemRef/IR/MemRef.h" 162bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/IR/VectorOps.h" 172bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 182bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 192bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 202bc4c3e9SNicolas Vasilache #include "mlir/IR/BuiltinTypes.h" 212bc4c3e9SNicolas Vasilache #include "mlir/IR/Location.h" 222bc4c3e9SNicolas Vasilache #include "mlir/IR/PatternMatch.h" 232bc4c3e9SNicolas Vasilache #include "mlir/IR/TypeUtilities.h" 242bc4c3e9SNicolas Vasilache 252bc4c3e9SNicolas Vasilache #define DEBUG_TYPE "vector-shape-cast-lowering" 262bc4c3e9SNicolas Vasilache 272bc4c3e9SNicolas Vasilache using namespace mlir; 282bc4c3e9SNicolas Vasilache using namespace mlir::vector; 292bc4c3e9SNicolas Vasilache 30*a7a4c16cSDiego Caballero /// Increments n-D `indices` by `step` starting from the innermost dimension. 31*a7a4c16cSDiego Caballero static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType, 32*a7a4c16cSDiego Caballero int step = 1) { 33*a7a4c16cSDiego Caballero for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) { 34*a7a4c16cSDiego Caballero assert(indices[dim] < vecType.getDimSize(dim) && 35*a7a4c16cSDiego Caballero "Indices are out of bound"); 36*a7a4c16cSDiego Caballero indices[dim] += step; 37*a7a4c16cSDiego Caballero if (indices[dim] < vecType.getDimSize(dim)) 388dffb71cSBenjamin Maxwell break; 39*a7a4c16cSDiego Caballero 40*a7a4c16cSDiego Caballero indices[dim] = 0; 41*a7a4c16cSDiego Caballero step = 1; 428dffb71cSBenjamin Maxwell } 438dffb71cSBenjamin Maxwell } 44*a7a4c16cSDiego Caballero 45*a7a4c16cSDiego Caballero namespace { 46*a7a4c16cSDiego Caballero /// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D 47*a7a4c16cSDiego Caballero /// vectors progressively. This iterates over the n-1 major dimensions of the 48*a7a4c16cSDiego Caballero /// n-D vector and performs rewrites into: 49*a7a4c16cSDiego Caballero /// vector.extract from n-D + vector.insert_strided_slice offset into 1-D 50*a7a4c16cSDiego Caballero class ShapeCastOpNDDownCastRewritePattern 51*a7a4c16cSDiego Caballero : public OpRewritePattern<vector::ShapeCastOp> { 52*a7a4c16cSDiego Caballero public: 53*a7a4c16cSDiego Caballero using OpRewritePattern::OpRewritePattern; 54*a7a4c16cSDiego Caballero 55*a7a4c16cSDiego Caballero LogicalResult matchAndRewrite(vector::ShapeCastOp op, 56*a7a4c16cSDiego Caballero PatternRewriter &rewriter) const override { 57*a7a4c16cSDiego Caballero auto sourceVectorType = op.getSourceVectorType(); 58*a7a4c16cSDiego Caballero auto resultVectorType = op.getResultVectorType(); 59*a7a4c16cSDiego Caballero if (sourceVectorType.isScalable() || resultVectorType.isScalable()) 60*a7a4c16cSDiego Caballero return failure(); 61*a7a4c16cSDiego Caballero 62*a7a4c16cSDiego Caballero int64_t srcRank = sourceVectorType.getRank(); 63*a7a4c16cSDiego Caballero int64_t resRank = resultVectorType.getRank(); 64*a7a4c16cSDiego Caballero if (srcRank < 2 || resRank != 1) 65*a7a4c16cSDiego Caballero return failure(); 66*a7a4c16cSDiego Caballero 67*a7a4c16cSDiego Caballero // Compute the number of 1-D vector elements involved in the reshape. 68*a7a4c16cSDiego Caballero int64_t numElts = 1; 69*a7a4c16cSDiego Caballero for (int64_t dim = 0; dim < srcRank - 1; ++dim) 70*a7a4c16cSDiego Caballero numElts *= sourceVectorType.getDimSize(dim); 71*a7a4c16cSDiego Caballero 72*a7a4c16cSDiego Caballero auto loc = op.getLoc(); 73*a7a4c16cSDiego Caballero SmallVector<int64_t> srcIdx(srcRank - 1, 0); 74*a7a4c16cSDiego Caballero SmallVector<int64_t> resIdx(resRank, 0); 75*a7a4c16cSDiego Caballero int64_t extractSize = sourceVectorType.getShape().back(); 76*a7a4c16cSDiego Caballero Value result = rewriter.create<arith::ConstantOp>( 77*a7a4c16cSDiego Caballero loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); 78*a7a4c16cSDiego Caballero 79*a7a4c16cSDiego Caballero // Compute the indices of each 1-D vector element of the source extraction 80*a7a4c16cSDiego Caballero // and destination slice insertion and generate such instructions. 81*a7a4c16cSDiego Caballero for (int64_t i = 0; i < numElts; ++i) { 82*a7a4c16cSDiego Caballero if (i != 0) { 83*a7a4c16cSDiego Caballero incIdx(srcIdx, sourceVectorType, /*step=*/1); 84*a7a4c16cSDiego Caballero incIdx(resIdx, resultVectorType, /*step=*/extractSize); 858dffb71cSBenjamin Maxwell } 868dffb71cSBenjamin Maxwell 87*a7a4c16cSDiego Caballero Value extract = 88*a7a4c16cSDiego Caballero rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx); 89*a7a4c16cSDiego Caballero result = rewriter.create<vector::InsertStridedSliceOp>( 90*a7a4c16cSDiego Caballero loc, extract, result, 91*a7a4c16cSDiego Caballero /*offsets=*/resIdx, /*strides=*/1); 92*a7a4c16cSDiego Caballero } 93*a7a4c16cSDiego Caballero 94*a7a4c16cSDiego Caballero rewriter.replaceOp(op, result); 95*a7a4c16cSDiego Caballero return success(); 96*a7a4c16cSDiego Caballero } 97*a7a4c16cSDiego Caballero }; 98*a7a4c16cSDiego Caballero 99*a7a4c16cSDiego Caballero /// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D 100*a7a4c16cSDiego Caballero /// vectors progressively. This iterates over the n-1 major dimension of the n-D 101*a7a4c16cSDiego Caballero /// vector and performs rewrites into: 102*a7a4c16cSDiego Caballero /// vector.extract_strided_slice from 1-D + vector.insert into n-D 103*a7a4c16cSDiego Caballero /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle. 104*a7a4c16cSDiego Caballero class ShapeCastOpNDUpCastRewritePattern 105*a7a4c16cSDiego Caballero : public OpRewritePattern<vector::ShapeCastOp> { 106*a7a4c16cSDiego Caballero public: 107*a7a4c16cSDiego Caballero using OpRewritePattern::OpRewritePattern; 108*a7a4c16cSDiego Caballero 109*a7a4c16cSDiego Caballero LogicalResult matchAndRewrite(vector::ShapeCastOp op, 110*a7a4c16cSDiego Caballero PatternRewriter &rewriter) const override { 111*a7a4c16cSDiego Caballero auto sourceVectorType = op.getSourceVectorType(); 112*a7a4c16cSDiego Caballero auto resultVectorType = op.getResultVectorType(); 113*a7a4c16cSDiego Caballero if (sourceVectorType.isScalable() || resultVectorType.isScalable()) 114*a7a4c16cSDiego Caballero return failure(); 115*a7a4c16cSDiego Caballero 116*a7a4c16cSDiego Caballero int64_t srcRank = sourceVectorType.getRank(); 117*a7a4c16cSDiego Caballero int64_t resRank = resultVectorType.getRank(); 118*a7a4c16cSDiego Caballero if (srcRank != 1 || resRank < 2) 119*a7a4c16cSDiego Caballero return failure(); 120*a7a4c16cSDiego Caballero 121*a7a4c16cSDiego Caballero // Compute the number of 1-D vector elements involved in the reshape. 122*a7a4c16cSDiego Caballero int64_t numElts = 1; 123*a7a4c16cSDiego Caballero for (int64_t dim = 0; dim < resRank - 1; ++dim) 124*a7a4c16cSDiego Caballero numElts *= resultVectorType.getDimSize(dim); 125*a7a4c16cSDiego Caballero 126*a7a4c16cSDiego Caballero // Compute the indices of each 1-D vector element of the source slice 127*a7a4c16cSDiego Caballero // extraction and destination insertion and generate such instructions. 128*a7a4c16cSDiego Caballero auto loc = op.getLoc(); 129*a7a4c16cSDiego Caballero SmallVector<int64_t> srcIdx(srcRank, 0); 130*a7a4c16cSDiego Caballero SmallVector<int64_t> resIdx(resRank - 1, 0); 131*a7a4c16cSDiego Caballero int64_t extractSize = resultVectorType.getShape().back(); 132*a7a4c16cSDiego Caballero Value result = rewriter.create<arith::ConstantOp>( 133*a7a4c16cSDiego Caballero loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); 134*a7a4c16cSDiego Caballero for (int64_t i = 0; i < numElts; ++i) { 135*a7a4c16cSDiego Caballero if (i != 0) { 136*a7a4c16cSDiego Caballero incIdx(srcIdx, sourceVectorType, /*step=*/extractSize); 137*a7a4c16cSDiego Caballero incIdx(resIdx, resultVectorType, /*step=*/1); 138*a7a4c16cSDiego Caballero } 139*a7a4c16cSDiego Caballero 140*a7a4c16cSDiego Caballero Value extract = rewriter.create<vector::ExtractStridedSliceOp>( 141*a7a4c16cSDiego Caballero loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize, 142*a7a4c16cSDiego Caballero /*strides=*/1); 143*a7a4c16cSDiego Caballero result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx); 144*a7a4c16cSDiego Caballero } 145*a7a4c16cSDiego Caballero rewriter.replaceOp(op, result); 146*a7a4c16cSDiego Caballero return success(); 147*a7a4c16cSDiego Caballero } 148*a7a4c16cSDiego Caballero }; 149*a7a4c16cSDiego Caballero 1502bc4c3e9SNicolas Vasilache // We typically should not lower general shape cast operations into data 1512bc4c3e9SNicolas Vasilache // movement instructions, since the assumption is that these casts are 1522bc4c3e9SNicolas Vasilache // optimized away during progressive lowering. For completeness, however, 1532bc4c3e9SNicolas Vasilache // we fall back to a reference implementation that moves all elements 1542bc4c3e9SNicolas Vasilache // into the right place if we get here. 1552bc4c3e9SNicolas Vasilache class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { 1562bc4c3e9SNicolas Vasilache public: 1572bc4c3e9SNicolas Vasilache using OpRewritePattern::OpRewritePattern; 1582bc4c3e9SNicolas Vasilache 1592bc4c3e9SNicolas Vasilache LogicalResult matchAndRewrite(vector::ShapeCastOp op, 1602bc4c3e9SNicolas Vasilache PatternRewriter &rewriter) const override { 1612bc4c3e9SNicolas Vasilache Location loc = op.getLoc(); 1622bc4c3e9SNicolas Vasilache auto sourceVectorType = op.getSourceVectorType(); 1632bc4c3e9SNicolas Vasilache auto resultVectorType = op.getResultVectorType(); 1642bc4c3e9SNicolas Vasilache 1658dffb71cSBenjamin Maxwell if (sourceVectorType.isScalable() || resultVectorType.isScalable()) 1668dffb71cSBenjamin Maxwell return failure(); 1678dffb71cSBenjamin Maxwell 168*a7a4c16cSDiego Caballero // Special case for n-D / 1-D lowerings with better implementations. 1692bc4c3e9SNicolas Vasilache int64_t srcRank = sourceVectorType.getRank(); 1702bc4c3e9SNicolas Vasilache int64_t resRank = resultVectorType.getRank(); 171*a7a4c16cSDiego Caballero if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1)) 1722bc4c3e9SNicolas Vasilache return failure(); 1732bc4c3e9SNicolas Vasilache 1742bc4c3e9SNicolas Vasilache // Generic ShapeCast lowering path goes all the way down to unrolled scalar 1752bc4c3e9SNicolas Vasilache // extract/insert chains. 1762bc4c3e9SNicolas Vasilache int64_t numElts = 1; 1772bc4c3e9SNicolas Vasilache for (int64_t r = 0; r < srcRank; r++) 1782bc4c3e9SNicolas Vasilache numElts *= sourceVectorType.getDimSize(r); 1792bc4c3e9SNicolas Vasilache // Replace with data movement operations: 1802bc4c3e9SNicolas Vasilache // x[0,0,0] = y[0,0] 1812bc4c3e9SNicolas Vasilache // x[0,0,1] = y[0,1] 1822bc4c3e9SNicolas Vasilache // x[0,1,0] = y[0,2] 1832bc4c3e9SNicolas Vasilache // etc., incrementing the two index vectors "row-major" 1842bc4c3e9SNicolas Vasilache // within the source and result shape. 185*a7a4c16cSDiego Caballero SmallVector<int64_t> srcIdx(srcRank, 0); 186*a7a4c16cSDiego Caballero SmallVector<int64_t> resIdx(resRank, 0); 1872bc4c3e9SNicolas Vasilache Value result = rewriter.create<arith::ConstantOp>( 1882bc4c3e9SNicolas Vasilache loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); 1892bc4c3e9SNicolas Vasilache for (int64_t i = 0; i < numElts; i++) { 1902bc4c3e9SNicolas Vasilache if (i != 0) { 191*a7a4c16cSDiego Caballero incIdx(srcIdx, sourceVectorType); 192*a7a4c16cSDiego Caballero incIdx(resIdx, resultVectorType); 1932bc4c3e9SNicolas Vasilache } 1940935c055SDiego Caballero 1950935c055SDiego Caballero Value extract; 1960935c055SDiego Caballero if (srcRank == 0) { 1970935c055SDiego Caballero // 0-D vector special case 1980935c055SDiego Caballero assert(srcIdx.empty() && "Unexpected indices for 0-D vector"); 1990935c055SDiego Caballero extract = rewriter.create<vector::ExtractElementOp>( 2000935c055SDiego Caballero loc, op.getSourceVectorType().getElementType(), op.getSource()); 2010935c055SDiego Caballero } else { 2020935c055SDiego Caballero extract = 2030935c055SDiego Caballero rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx); 2040935c055SDiego Caballero } 2050935c055SDiego Caballero 2060935c055SDiego Caballero if (resRank == 0) { 2070935c055SDiego Caballero // 0-D vector special case 2080935c055SDiego Caballero assert(resIdx.empty() && "Unexpected indices for 0-D vector"); 2090935c055SDiego Caballero result = rewriter.create<vector::InsertElementOp>(loc, extract, result); 2100935c055SDiego Caballero } else { 2110935c055SDiego Caballero result = 2120935c055SDiego Caballero rewriter.create<vector::InsertOp>(loc, extract, result, resIdx); 2130935c055SDiego Caballero } 2142bc4c3e9SNicolas Vasilache } 2152bc4c3e9SNicolas Vasilache rewriter.replaceOp(op, result); 2162bc4c3e9SNicolas Vasilache return success(); 2172bc4c3e9SNicolas Vasilache } 2188dffb71cSBenjamin Maxwell }; 2192bc4c3e9SNicolas Vasilache 2208dffb71cSBenjamin Maxwell /// A shape_cast lowering for scalable vectors with a single trailing scalable 2218dffb71cSBenjamin Maxwell /// dimension. This is similar to the general shape_cast lowering but makes use 2228dffb71cSBenjamin Maxwell /// of vector.scalable.insert and vector.scalable.extract to move elements a 2238dffb71cSBenjamin Maxwell /// subvector at a time. 2248dffb71cSBenjamin Maxwell /// 2258dffb71cSBenjamin Maxwell /// E.g.: 2268dffb71cSBenjamin Maxwell /// ``` 2278dffb71cSBenjamin Maxwell /// // Flatten scalable vector 2288dffb71cSBenjamin Maxwell /// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32> 2298dffb71cSBenjamin Maxwell /// ``` 2308dffb71cSBenjamin Maxwell /// is rewritten to: 2318dffb71cSBenjamin Maxwell /// ``` 2328dffb71cSBenjamin Maxwell /// // Flatten scalable vector 2338dffb71cSBenjamin Maxwell /// %c = arith.constant dense<0> : vector<[8]xi32> 2349816edc9SCullen Rhodes /// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> 2358dffb71cSBenjamin Maxwell /// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32> 2369816edc9SCullen Rhodes /// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32> 2378dffb71cSBenjamin Maxwell /// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32> 2388dffb71cSBenjamin Maxwell /// ``` 2398dffb71cSBenjamin Maxwell /// or: 2408dffb71cSBenjamin Maxwell /// ``` 2418dffb71cSBenjamin Maxwell /// // Un-flatten scalable vector 2428dffb71cSBenjamin Maxwell /// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32> 2438dffb71cSBenjamin Maxwell /// ``` 2448dffb71cSBenjamin Maxwell /// is rewritten to: 2458dffb71cSBenjamin Maxwell /// ``` 2468dffb71cSBenjamin Maxwell /// // Un-flatten scalable vector 2478dffb71cSBenjamin Maxwell /// %c = arith.constant dense<0> : vector<2x1x[4]xi32> 2488dffb71cSBenjamin Maxwell /// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32> 2498dffb71cSBenjamin Maxwell /// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> 2508dffb71cSBenjamin Maxwell /// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32> 2518dffb71cSBenjamin Maxwell /// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> 2528dffb71cSBenjamin Maxwell /// ``` 2538dffb71cSBenjamin Maxwell class ScalableShapeCastOpRewritePattern 2548dffb71cSBenjamin Maxwell : public OpRewritePattern<vector::ShapeCastOp> { 2558dffb71cSBenjamin Maxwell public: 2568dffb71cSBenjamin Maxwell using OpRewritePattern::OpRewritePattern; 2578dffb71cSBenjamin Maxwell 2588dffb71cSBenjamin Maxwell LogicalResult matchAndRewrite(vector::ShapeCastOp op, 2598dffb71cSBenjamin Maxwell PatternRewriter &rewriter) const override { 2608dffb71cSBenjamin Maxwell 2618dffb71cSBenjamin Maxwell Location loc = op.getLoc(); 2628dffb71cSBenjamin Maxwell auto sourceVectorType = op.getSourceVectorType(); 2638dffb71cSBenjamin Maxwell auto resultVectorType = op.getResultVectorType(); 2648dffb71cSBenjamin Maxwell auto srcRank = sourceVectorType.getRank(); 2658dffb71cSBenjamin Maxwell auto resRank = resultVectorType.getRank(); 2668dffb71cSBenjamin Maxwell 2678dffb71cSBenjamin Maxwell // This can only lower shape_casts where both the source and result types 2688dffb71cSBenjamin Maxwell // have a single trailing scalable dimension. This is because there are no 2698dffb71cSBenjamin Maxwell // legal representation of other scalable types in LLVM (and likely won't be 2708dffb71cSBenjamin Maxwell // soon). There are also (currently) no operations that can index or extract 271*a7a4c16cSDiego Caballero // from >= 2-D scalable vectors or scalable vectors of fixed vectors. 2728dffb71cSBenjamin Maxwell if (!isTrailingDimScalable(sourceVectorType) || 2738dffb71cSBenjamin Maxwell !isTrailingDimScalable(resultVectorType)) { 2748dffb71cSBenjamin Maxwell return failure(); 2752bc4c3e9SNicolas Vasilache } 2768dffb71cSBenjamin Maxwell 2778dffb71cSBenjamin Maxwell // The sizes of the trailing dimension of the source and result vectors, the 2788dffb71cSBenjamin Maxwell // size of subvector to move, and the number of elements in the vectors. 2798dffb71cSBenjamin Maxwell // These are "min" sizes as they are the size when vscale == 1. 2808dffb71cSBenjamin Maxwell auto minSourceTrailingSize = sourceVectorType.getShape().back(); 2818dffb71cSBenjamin Maxwell auto minResultTrailingSize = resultVectorType.getShape().back(); 2828dffb71cSBenjamin Maxwell auto minExtractionSize = 2838dffb71cSBenjamin Maxwell std::min(minSourceTrailingSize, minResultTrailingSize); 2848dffb71cSBenjamin Maxwell int64_t minNumElts = 1; 2858dffb71cSBenjamin Maxwell for (auto size : sourceVectorType.getShape()) 2868dffb71cSBenjamin Maxwell minNumElts *= size; 2878dffb71cSBenjamin Maxwell 2888dffb71cSBenjamin Maxwell // The subvector type to move from the source to the result. Note that this 2898dffb71cSBenjamin Maxwell // is a scalable vector. This rewrite will generate code in terms of the 2908dffb71cSBenjamin Maxwell // "min" size (vscale == 1 case), that scales to any vscale. 2918dffb71cSBenjamin Maxwell auto extractionVectorType = VectorType::get( 2928dffb71cSBenjamin Maxwell {minExtractionSize}, sourceVectorType.getElementType(), {true}); 2938dffb71cSBenjamin Maxwell 2948dffb71cSBenjamin Maxwell Value result = rewriter.create<arith::ConstantOp>( 2958dffb71cSBenjamin Maxwell loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); 2968dffb71cSBenjamin Maxwell 297*a7a4c16cSDiego Caballero SmallVector<int64_t> srcIdx(srcRank, 0); 298*a7a4c16cSDiego Caballero SmallVector<int64_t> resIdx(resRank, 0); 2998dffb71cSBenjamin Maxwell 3008dffb71cSBenjamin Maxwell // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils) 3018dffb71cSBenjamin Maxwell // once D150000 lands. 3028dffb71cSBenjamin Maxwell Value currentResultScalableVector; 3038dffb71cSBenjamin Maxwell Value currentSourceScalableVector; 3048dffb71cSBenjamin Maxwell for (int64_t i = 0; i < minNumElts; i += minExtractionSize) { 3058dffb71cSBenjamin Maxwell // 1. Extract a scalable subvector from the source vector. 3068dffb71cSBenjamin Maxwell if (!currentSourceScalableVector) { 3078dffb71cSBenjamin Maxwell if (srcRank != 1) { 3088dffb71cSBenjamin Maxwell currentSourceScalableVector = rewriter.create<vector::ExtractOp>( 3098dffb71cSBenjamin Maxwell loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back()); 3108dffb71cSBenjamin Maxwell } else { 3118dffb71cSBenjamin Maxwell currentSourceScalableVector = op.getSource(); 3128dffb71cSBenjamin Maxwell } 3138dffb71cSBenjamin Maxwell } 3148dffb71cSBenjamin Maxwell Value sourceSubVector = currentSourceScalableVector; 3158dffb71cSBenjamin Maxwell if (minExtractionSize < minSourceTrailingSize) { 3168dffb71cSBenjamin Maxwell sourceSubVector = rewriter.create<vector::ScalableExtractOp>( 3178dffb71cSBenjamin Maxwell loc, extractionVectorType, sourceSubVector, srcIdx.back()); 3188dffb71cSBenjamin Maxwell } 3198dffb71cSBenjamin Maxwell 3208dffb71cSBenjamin Maxwell // 2. Insert the scalable subvector into the result vector. 3218dffb71cSBenjamin Maxwell if (!currentResultScalableVector) { 3228dffb71cSBenjamin Maxwell if (minExtractionSize == minResultTrailingSize) { 3238dffb71cSBenjamin Maxwell currentResultScalableVector = sourceSubVector; 3248dffb71cSBenjamin Maxwell } else if (resRank != 1) { 3258dffb71cSBenjamin Maxwell currentResultScalableVector = rewriter.create<vector::ExtractOp>( 3268dffb71cSBenjamin Maxwell loc, result, llvm::ArrayRef(resIdx).drop_back()); 3278dffb71cSBenjamin Maxwell } else { 3288dffb71cSBenjamin Maxwell currentResultScalableVector = result; 3298dffb71cSBenjamin Maxwell } 3308dffb71cSBenjamin Maxwell } 3318dffb71cSBenjamin Maxwell if (minExtractionSize < minResultTrailingSize) { 3328dffb71cSBenjamin Maxwell currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>( 3338dffb71cSBenjamin Maxwell loc, sourceSubVector, currentResultScalableVector, resIdx.back()); 3348dffb71cSBenjamin Maxwell } 3358dffb71cSBenjamin Maxwell 3368dffb71cSBenjamin Maxwell // 3. Update the source and result scalable vectors if needed. 3378dffb71cSBenjamin Maxwell if (resIdx.back() + minExtractionSize >= minResultTrailingSize && 3388dffb71cSBenjamin Maxwell currentResultScalableVector != result) { 3398dffb71cSBenjamin Maxwell // Finished row of result. Insert complete scalable vector into result 3408dffb71cSBenjamin Maxwell // (n-D) vector. 3418dffb71cSBenjamin Maxwell result = rewriter.create<vector::InsertOp>( 3428dffb71cSBenjamin Maxwell loc, currentResultScalableVector, result, 3438dffb71cSBenjamin Maxwell llvm::ArrayRef(resIdx).drop_back()); 3448dffb71cSBenjamin Maxwell currentResultScalableVector = {}; 3458dffb71cSBenjamin Maxwell } 3468dffb71cSBenjamin Maxwell if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) { 3478dffb71cSBenjamin Maxwell // Finished row of source. 3488dffb71cSBenjamin Maxwell currentSourceScalableVector = {}; 3498dffb71cSBenjamin Maxwell } 3508dffb71cSBenjamin Maxwell 3518dffb71cSBenjamin Maxwell // 4. Increment the insert/extract indices, stepping by minExtractionSize 3528dffb71cSBenjamin Maxwell // for the trailing dimensions. 353*a7a4c16cSDiego Caballero incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize); 354*a7a4c16cSDiego Caballero incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize); 3558dffb71cSBenjamin Maxwell } 3568dffb71cSBenjamin Maxwell 3578dffb71cSBenjamin Maxwell rewriter.replaceOp(op, result); 3588dffb71cSBenjamin Maxwell return success(); 3598dffb71cSBenjamin Maxwell } 3608dffb71cSBenjamin Maxwell 3618dffb71cSBenjamin Maxwell static bool isTrailingDimScalable(VectorType type) { 3628dffb71cSBenjamin Maxwell return type.getRank() >= 1 && type.getScalableDims().back() && 3638dffb71cSBenjamin Maxwell !llvm::is_contained(type.getScalableDims().drop_back(), true); 3642bc4c3e9SNicolas Vasilache } 3652bc4c3e9SNicolas Vasilache }; 3668dffb71cSBenjamin Maxwell 3672bc4c3e9SNicolas Vasilache } // namespace 3682bc4c3e9SNicolas Vasilache 3692bc4c3e9SNicolas Vasilache void mlir::vector::populateVectorShapeCastLoweringPatterns( 3702bc4c3e9SNicolas Vasilache RewritePatternSet &patterns, PatternBenefit benefit) { 371*a7a4c16cSDiego Caballero patterns.add<ShapeCastOpNDDownCastRewritePattern, 372*a7a4c16cSDiego Caballero ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern, 3738dffb71cSBenjamin Maxwell ScalableShapeCastOpRewritePattern>(patterns.getContext(), 3748dffb71cSBenjamin Maxwell benefit); 3752bc4c3e9SNicolas Vasilache } 376