//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.shape_cast' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

/// Increments n-D `indices` by `step` starting from the innermost dimension.
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
                   int step = 1) {
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    assert(indices[dim] < vecType.getDimSize(dim) &&
           "Indices are out of bounds");
    indices[dim] += step;
    if (indices[dim] < vecType.getDimSize(dim))
      break;

    indices[dim] = 0;
    step = 1;
  }
}

namespace {
/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening n-D to 1-D
/// vectors progressively. This iterates over the n-1 major dimensions of the
/// n-D vector and performs rewrites into:
///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOpNDDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if (srcRank < 2 || resRank != 1)
      return failure();

    // Compute the number of 1-D vector elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
      numElts *= sourceVectorType.getDimSize(dim);

    auto loc = op.getLoc();
    SmallVector<int64_t> srcIdx(srcRank - 1, 0);
    SmallVector<int64_t> resIdx(resRank, 0);
    int64_t extractSize = sourceVectorType.getShape().back();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));

    // Compute the indices of each 1-D vector element of the source extraction
    // and destination slice insertion and generate such instructions.
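    // For instance (an illustrative sketch; %src and %cst are placeholder
    // names), casting vector<2x3xf32> to vector<6xf32> is expected to expand
    // to roughly:
    //   %e0 = vector.extract %src[0] : vector<3xf32> from vector<2x3xf32>
    //   %r0 = vector.insert_strided_slice %e0, %cst
    //           {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
    //   %e1 = vector.extract %src[1] : vector<3xf32> from vector<2x3xf32>
    //   %r1 = vector.insert_strided_slice %e1, %r0
    //           {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>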
    for (int64_t i = 0; i < numElts; ++i) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, /*step=*/1);
        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
      }

      Value extract =
          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extract, result,
          /*offsets=*/resIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
/// vectors progressively. This iterates over the n-1 major dimensions of the
/// n-D vector and performs rewrites into:
///   vector.extract_strided_slice from 1-D + vector.insert into n-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
class ShapeCastOpNDUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if (srcRank != 1 || resRank < 2)
      return failure();

    // Compute the number of 1-D vector elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t dim = 0; dim < resRank - 1; ++dim)
      numElts *= resultVectorType.getDimSize(dim);

    // Compute the indices of each 1-D vector element of the source slice
    // extraction and destination insertion and generate such instructions.
    auto loc = op.getLoc();
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank - 1, 0);
    int64_t extractSize = resultVectorType.getShape().back();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; ++i) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
        incIdx(resIdx, resultVectorType, /*step=*/1);
      }

      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
          /*strides=*/1);
      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
// we fall back to a reference implementation that moves all elements
// into the right place if we get here.
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    // Special case for n-D / 1-D lowerings with better implementations.
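    // (The dedicated NDDownCast/NDUpCast patterns above already handle these
    // shapes, so bail out here and leave them to those patterns.)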
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
    // Replace with data movement operations:
    //   x[0,0,0] = y[0,0]
    //   x[0,0,1] = y[0,1]
    //   x[0,1,0] = y[0,2]
    // etc., incrementing the two index vectors "row-major"
    // within the source and result shape.
    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType);
        incIdx(resIdx, resultVectorType);
      }

      Value extract;
      if (srcRank == 0) {
        // 0-D vector special case.
        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
        extract = rewriter.create<vector::ExtractElementOp>(
            loc, op.getSourceVectorType().getElementType(), op.getSource());
      } else {
        extract =
            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      }

      if (resRank == 0) {
        // 0-D vector special case.
        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
      } else {
        result =
            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
      }
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

/// A shape_cast lowering for scalable vectors with a single trailing scalable
/// dimension. This is similar to the general shape_cast lowering but makes use
/// of vector.scalable.insert and vector.scalable.extract to move elements a
/// subvector at a time.
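/// Only shape_casts where both the source and result types have a single,
/// trailing scalable dimension are supported (see isTrailingDimScalable).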
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = arith.constant dense<0> : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This pattern can only lower shape_casts where both the source and result
    // types have a single trailing scalable dimension. This is because there
    // is no legal representation of other scalable types in LLVM (and likely
    // won't be anytime soon). There are also (currently) no operations that
    // can index or extract from >= 2-D scalable vectors or scalable vectors of
    // fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
    }

    // The sizes of the trailing dimensions of the source and result vectors,
    // the size of the subvector to move, and the number of elements in the
    // vectors. These are "min" sizes as they are the sizes when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (the vscale == 1 case), which scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));

    SmallVector<int64_t> srcIdx(srcRank, 0);
    SmallVector<int64_t> resIdx(resRank, 0);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
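    // Track the (1-D, scalable) source row currently being read and the
    // result row currently being assembled; each is reset in step 3 below
    // once its row has been fully consumed or produced.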
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOpNDDownCastRewritePattern,
               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                  benefit);
}