//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// inverse of the given permutation.
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}
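
// For example (an illustrative sketch, not exercised elsewhere in this
// file): with permutation = [2, 0, 1] and in_bounds = [true, false, false],
// element i of the input moves to position permutation[i], so the resulting
// attribute is [false, false, true].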

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());
  VectorType newVecType =
      VectorType::get(newShape, originalVecType.getElementType());
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = cast<VectorType>(broadcasted.getType()).getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}
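
// For example (an illustrative sketch): with addedRank = 1, extendVectorRank
// turns vector<4x8xf32> into vector<1x4x8xf32>, while extendMaskRank turns
// vector<4x8xi1> into vector<4x8x1xi1> (broadcast to vector<1x4x8xi1>, then
// transpose with permutation [1, 2, 0]).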

//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(
          op, "map permutation is identity, nothing to do");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
                                                     transposePerm);
    return success();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have
/// broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [1, 2, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map. E.g.:
    //   (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    // becomes:
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return cast<AffineDimExpr>(expr).getPosition();
                    });
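
    // Continuing the example above (an illustrative sketch):
    //   inversePermutation((d0, d1, d2) -> (d2, d0, d1))
    //     = (d0, d1, d2) -> (d1, d2, d0)
    // so indices = [1, 2, 0] and the vector is transposed with [1, 2, 0]
    // before the minor identity transfer_write is emitted.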

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        op.getMask(), newInBoundsAttr);

    return success();
  }
};

/// Convert a transfer_write op with a map which is not a permutation of a
/// minor identity into a vector.broadcast + transfer_write whose map is a
/// permutation of a minor identity. The missing inner dimensions are added to
/// the map and matching leading unit dimensions are added to the written
/// vector. Ex:
/// ```
/// vector.transfer_write %v
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///     vector<8x16xf32>
/// ```
/// into:
/// ```
/// %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
/// vector.transfer_write %v1
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///     vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost existing
    // dimension, then deduce the missing inner dimensions.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once one dimension that exists in the map has been found, keep track
      // of all the missing dimensions after it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the newly added dimensions are in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        newMask, newInBoundsAttr);
    return success();
  }
};
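
// Note on masks in the pattern above (an illustrative sketch): the unit
// dimensions added to the written vector are leading (via vector.broadcast),
// while the unit dimensions added to the mask are trailing; for the example
// in the comment, a mask of type vector<8x16xi1> would become
// vector<8x16x1xi1> via extendMaskRank.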

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!constExpr || constExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate the new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // TODO: support zero-dimension vectors natively. See:
    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
    // In the meantime, lower these to a scalar load when they pop up.
    if (reducedShapeRank == 0) {
      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                       newRead);
      return success();
    }

    // The 0-d case was handled above, so the reduced shape is never empty.
    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                     newRead);
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}
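
// Typical usage (an illustrative sketch; the choice of pattern driver is up
// to the caller):
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));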

//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - The stride of the most minor memref dimension is 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is
///   allowed).
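/// For instance (an illustrative sketch, not taken from a test):
///   %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true]}
///       : memref<?x?xf32>, vector<4xf32>
/// lowers to:
///   %v = vector.load %mem[%i, %j] : memref<?x?xf32>, vector<4xf32>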
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass through, as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *loadOp;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  std::optional<unsigned> maxTransferRank;
};

/// Replace a 0-d (or, more generally, single-element) vector.load with a
/// memref.load + vector.broadcast.
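/// For instance (an illustrative sketch):
///   %0 = vector.load %mem[%i] : memref<?xf32>, vector<1xf32>
/// becomes:
///   %s = memref.load %mem[%i] : memref<?xf32>
///   %0 = vector.broadcast %s : f32 to vector<1xf32>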
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");

    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d (or, more generally, single-element) vector.store with a
/// scalar extract + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");

    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - The stride of the most minor memref dimension is 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
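/// For instance (an illustrative sketch, not taken from a test):
///   vector.transfer_write %v, %mem[%i, %j] {in_bounds = [true]}
///       : vector<4xf32>, memref<?x?xf32>
/// lowers to:
///   vector.store %v, %mem[%i, %j] : memref<?x?xf32>, vector<4xf32>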
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass through, as it is supported.
    if (!write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}