1 //===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements target-independent rewrites and utilities to lower the 10 // 'vector.shape_cast' operation. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Arith/IR/Arith.h" 16 #include "mlir/Dialect/Arith/Utils/Utils.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/SCF/IR/SCF.h" 20 #include "mlir/Dialect/Tensor/IR/Tensor.h" 21 #include "mlir/Dialect/Utils/IndexingUtils.h" 22 #include "mlir/Dialect/Utils/StructuredOpsUtils.h" 23 #include "mlir/Dialect/Vector/IR/VectorOps.h" 24 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 25 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" 26 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 27 #include "mlir/IR/BuiltinAttributeInterfaces.h" 28 #include "mlir/IR/BuiltinTypes.h" 29 #include "mlir/IR/ImplicitLocOpBuilder.h" 30 #include "mlir/IR/Location.h" 31 #include "mlir/IR/Matchers.h" 32 #include "mlir/IR/PatternMatch.h" 33 #include "mlir/IR/TypeUtilities.h" 34 #include "mlir/Interfaces/VectorInterfaces.h" 35 36 #define DEBUG_TYPE "vector-shape-cast-lowering" 37 38 using namespace mlir; 39 using namespace mlir::vector; 40 41 namespace { 42 /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D 43 /// vectors progressively on the way to target llvm.matrix intrinsics. 
44 /// This iterates over the most major dimension of the 2-D vector and performs 45 /// rewrites into: 46 /// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D 47 class ShapeCastOp2DDownCastRewritePattern 48 : public OpRewritePattern<vector::ShapeCastOp> { 49 public: 50 using OpRewritePattern::OpRewritePattern; 51 52 LogicalResult matchAndRewrite(vector::ShapeCastOp op, 53 PatternRewriter &rewriter) const override { 54 auto sourceVectorType = op.getSourceVectorType(); 55 auto resultVectorType = op.getResultVectorType(); 56 57 if (sourceVectorType.isScalable() || resultVectorType.isScalable()) 58 return failure(); 59 60 if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1) 61 return failure(); 62 63 auto loc = op.getLoc(); 64 Value desc = rewriter.create<arith::ConstantOp>( 65 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); 66 unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; 67 for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { 68 Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i); 69 desc = rewriter.create<vector::InsertStridedSliceOp>( 70 loc, vec, desc, 71 /*offsets=*/i * mostMinorVectorSize, /*strides=*/1); 72 } 73 rewriter.replaceOp(op, desc); 74 return success(); 75 } 76 }; 77 78 /// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D 79 /// vectors progressively. 80 /// This iterates over the most major dimension of the 2-D vector and performs 81 /// rewrites into: 82 /// vector.extract_strided_slice from 1-D + vector.insert into 2-D 83 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle. 
84 class ShapeCastOp2DUpCastRewritePattern 85 : public OpRewritePattern<vector::ShapeCastOp> { 86 public: 87 using OpRewritePattern::OpRewritePattern; 88 89 LogicalResult matchAndRewrite(vector::ShapeCastOp op, 90 PatternRewriter &rewriter) const override { 91 auto sourceVectorType = op.getSourceVectorType(); 92 auto resultVectorType = op.getResultVectorType(); 93 94 if (sourceVectorType.isScalable() || resultVectorType.isScalable()) 95 return failure(); 96 97 if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2) 98 return failure(); 99 100 auto loc = op.getLoc(); 101 Value desc = rewriter.create<arith::ConstantOp>( 102 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); 103 unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; 104 for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { 105 Value vec = rewriter.create<vector::ExtractStridedSliceOp>( 106 loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize, 107 /*sizes=*/mostMinorVectorSize, 108 /*strides=*/1); 109 desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i); 110 } 111 rewriter.replaceOp(op, desc); 112 return success(); 113 } 114 }; 115 116 static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp, 117 int dimIdx, int initialStep = 1) { 118 int step = initialStep; 119 for (int d = dimIdx; d >= 0; d--) { 120 idx[d] += step; 121 if (idx[d] >= tp.getDimSize(d)) { 122 idx[d] = 0; 123 step = 1; 124 } else { 125 break; 126 } 127 } 128 } 129 130 // We typically should not lower general shape cast operations into data 131 // movement instructions, since the assumption is that these casts are 132 // optimized away during progressive lowering. For completeness, however, 133 // we fall back to a reference implementation that moves all elements 134 // into the right place if we get here. 
135 class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { 136 public: 137 using OpRewritePattern::OpRewritePattern; 138 139 LogicalResult matchAndRewrite(vector::ShapeCastOp op, 140 PatternRewriter &rewriter) const override { 141 Location loc = op.getLoc(); 142 auto sourceVectorType = op.getSourceVectorType(); 143 auto resultVectorType = op.getResultVectorType(); 144 145 if (sourceVectorType.isScalable() || resultVectorType.isScalable()) 146 return failure(); 147 148 // Special case 2D / 1D lowerings with better implementations. 149 // TODO: make is ND / 1D to allow generic ND -> 1D -> MD. 150 int64_t srcRank = sourceVectorType.getRank(); 151 int64_t resRank = resultVectorType.getRank(); 152 if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2)) 153 return failure(); 154 155 // Generic ShapeCast lowering path goes all the way down to unrolled scalar 156 // extract/insert chains. 157 // TODO: consider evolving the semantics to only allow 1D source or dest and 158 // drop this potentially very expensive lowering. 159 // Compute number of elements involved in the reshape. 160 int64_t numElts = 1; 161 for (int64_t r = 0; r < srcRank; r++) 162 numElts *= sourceVectorType.getDimSize(r); 163 // Replace with data movement operations: 164 // x[0,0,0] = y[0,0] 165 // x[0,0,1] = y[0,1] 166 // x[0,1,0] = y[0,2] 167 // etc., incrementing the two index vectors "row-major" 168 // within the source and result shape. 
169 SmallVector<int64_t> srcIdx(srcRank); 170 SmallVector<int64_t> resIdx(resRank); 171 Value result = rewriter.create<arith::ConstantOp>( 172 loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); 173 for (int64_t i = 0; i < numElts; i++) { 174 if (i != 0) { 175 incIdx(srcIdx, sourceVectorType, srcRank - 1); 176 incIdx(resIdx, resultVectorType, resRank - 1); 177 } 178 179 Value extract; 180 if (srcRank == 0) { 181 // 0-D vector special case 182 assert(srcIdx.empty() && "Unexpected indices for 0-D vector"); 183 extract = rewriter.create<vector::ExtractElementOp>( 184 loc, op.getSourceVectorType().getElementType(), op.getSource()); 185 } else { 186 extract = 187 rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx); 188 } 189 190 if (resRank == 0) { 191 // 0-D vector special case 192 assert(resIdx.empty() && "Unexpected indices for 0-D vector"); 193 result = rewriter.create<vector::InsertElementOp>(loc, extract, result); 194 } else { 195 result = 196 rewriter.create<vector::InsertOp>(loc, extract, result, resIdx); 197 } 198 } 199 rewriter.replaceOp(op, result); 200 return success(); 201 } 202 }; 203 204 /// A shape_cast lowering for scalable vectors with a single trailing scalable 205 /// dimension. This is similar to the general shape_cast lowering but makes use 206 /// of vector.scalable.insert and vector.scalable.extract to move elements a 207 /// subvector at a time. 
///
/// E.g.:
/// ```
/// // Flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Flatten scalable vector
/// %c = arith.constant dense<0> : vector<[8]xi32>
/// %0 = vector.extract %arg0[0, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
/// %2 = vector.extract %arg0[1, 0] : vector<[4]xi32> from vector<2x1x[4]xi32>
/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
/// ```
/// or:
/// ```
/// // Un-flatten scalable vector
/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// // Un-flatten scalable vector
/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
/// ```
class ScalableShapeCastOpRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Rewrites `op` into a chain of extract / scalable.extract /
  /// scalable.insert / insert operations that moves one scalable subvector
  /// per loop iteration. Returns failure, leaving the IR untouched, unless
  /// both the source and the result type satisfy `isTrailingDimScalable`.
  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();
    auto srcRank = sourceVectorType.getRank();
    auto resRank = resultVectorType.getRank();

    // This can only lower shape_casts where both the source and result types
    // have a single trailing scalable dimension. This is because there is no
    // legal representation of other scalable types in LLVM (and likely won't be
    // soon). There are also (currently) no operations that can index or extract
    // from >= 2D scalable vectors or scalable vectors of fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
    }

    // The sizes of the trailing dimension of the source and result vectors, the
    // size of subvector to move, and the number of elements in the vectors.
    // These are "min" sizes as they are the size when vscale == 1.
    auto minSourceTrailingSize = sourceVectorType.getShape().back();
    auto minResultTrailingSize = resultVectorType.getShape().back();
    // The subvector moved per iteration is as wide as the narrower of the two
    // trailing dimensions.
    auto minExtractionSize =
        std::min(minSourceTrailingSize, minResultTrailingSize);
    int64_t minNumElts = 1;
    for (auto size : sourceVectorType.getShape())
      minNumElts *= size;

    // The subvector type to move from the source to the result. Note that this
    // is a scalable vector. This rewrite will generate code in terms of the
    // "min" size (vscale == 1 case), that scales to any vscale.
    auto extractionVectorType = VectorType::get(
        {minExtractionSize}, sourceVectorType.getElementType(), {true});

    // The result starts out as an all-zero constant and is updated in place by
    // the inserts created below.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));

    // Current extract / insert positions in the source and result shapes.
    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank);

    // TODO: Try rewriting this with StaticTileOffsetRange (from IndexingUtils)
    // once D150000 lands.
    // The 1-D scalable "rows" currently being written to / read from. A null
    // value means a new row has to be fetched on the next iteration.
    Value currentResultScalableVector;
    Value currentSourceScalableVector;
    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
      // 1. Extract a scalable subvector from the source vector.
      if (!currentSourceScalableVector) {
        if (srcRank != 1) {
          // Fetch the 1-D row addressed by the leading source indices.
          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
        } else {
          // A rank-1 source is itself the only row.
          currentSourceScalableVector = op.getSource();
        }
      }
      Value sourceSubVector = currentSourceScalableVector;
      if (minExtractionSize < minSourceTrailingSize) {
        // The source row is wider than the subvector being moved: slice out
        // the part starting at the current trailing source index.
        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
            loc, extractionVectorType, sourceSubVector, srcIdx.back());
      }

      // 2. Insert the scalable subvector into the result vector.
      if (!currentResultScalableVector) {
        if (minExtractionSize == minResultTrailingSize) {
          // The subvector already is a complete result row.
          currentResultScalableVector = sourceSubVector;
        } else if (resRank != 1) {
          // Fetch the 1-D result row addressed by the leading result indices.
          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
              loc, result, llvm::ArrayRef(resIdx).drop_back());
        } else {
          // A rank-1 result is itself the only row.
          currentResultScalableVector = result;
        }
      }
      if (minExtractionSize < minResultTrailingSize) {
        // The result row is wider than the subvector being moved: place the
        // subvector at the current trailing result index.
        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
      }

      // 3. Update the source and result scalable vectors if needed.
      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
          currentResultScalableVector != result) {
        // Finished row of result. Insert complete scalable vector into result
        // (n-D) vector.
        result = rewriter.create<vector::InsertOp>(
            loc, currentResultScalableVector, result,
            llvm::ArrayRef(resIdx).drop_back());
        // Reset so that the next iteration starts a fresh result row.
        currentResultScalableVector = {};
      }
      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
        // Finished row of source.
        currentSourceScalableVector = {};
      }

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
    }

    rewriter.replaceOp(op, result);
    return success();
  }

  /// Returns true if `type` has rank >= 1, its last dimension is scalable, and
  /// none of its other dimensions is scalable.
  static bool isTrailingDimScalable(VectorType type) {
    return type.getRank() >= 1 && type.getScalableDims().back() &&
           !llvm::is_contained(type.getScalableDims().drop_back(), true);
  }
};

} // namespace

/// Adds the four shape_cast lowering patterns defined in this file (2-D
/// down-cast, 2-D up-cast, generic fixed-size, and scalable) to `patterns`,
/// all with the given `benefit`.
void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                  benefit);
}