//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// inverse of the given permutation.
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());
  VectorType newVecType =
      VectorType::get(newShape, originalVecType.getElementType());
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = broadcasted.getType().cast<VectorType>().getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op,
                                         "map is already identity, no-op");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
                                                     transposePerm);
    return success();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have
/// broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map. E.g.:
    //   map  = (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return expr.cast<AffineDimExpr>().getPosition();
                    });

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        op.getInBounds() ? inverseTransposeInBoundsAttr(
                               rewriter, op.getInBounds().value(), permutation)
                         : ArrayAttr();

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        op.getMask(), newInBoundsAttr);

    return success();
  }
};

/// Convert a transfer_write op with a map that is not a permutation of a
/// minor identity into a vector.broadcast + transfer_write with a permutation
/// of minor identity map by adding unit dims on the inner dimensions. Ex:
/// ```
///   vector.transfer_write %v
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///     vector<8x16xf32>
/// ```
/// into:
/// ```
///   %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
///   vector.transfer_write %v1
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///     vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost dimension that
    // appears in the map, then deduce the missing inner dimensions.
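    // E.g., for the map (d0, d1, d2, d3) -> (d1, d2) from the example above:
    // d1 is the outermost dimension present in the map, so the missing inner
    // dimension d3 gets a unit dim, while the missing outer dimension d0 is
    // left out.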
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[exp.cast<AffineDimExpr>().getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once we have found an outer dimension that exists in the map, keep
      // track of all the missing dimensions after it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the newly added dimensions are in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
        newMask, newInBoundsAttr);
    return success();
  }
};

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp op,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto constExpr = expr.dyn_cast<AffineConstantExpr>();
      if (!constExpr || constExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
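    // E.g., for (d0, d1, d2, d3) -> (0, d1, 0, d3), dropping the leading zero
    // leaves (d1, 0, d3), which is a minor identity with broadcasting, so the
    // rewrite can proceed.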
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // TODO: support zero-dimension vectors natively. See:
    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
    // In the meantime, lower these to a scalar load when they pop up.
    if (reducedShapeRank == 0) {
      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                       newRead);
      return success();
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
    if (newShape.empty())
      return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
                                                     newRead);
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is
///   allowed).
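/// Ex (illustrative; assumes the conditions above hold):
///     %v = vector.transfer_read %mem[%i, %j], %pad {in_bounds = [true]}
///         : memref<8x16xf32>, vector<4xf32>
/// into:
///     %v = vector.load %mem[%i, %j] : memref<8x16xf32>, vector<4xf32>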
struct TransferReadToVectorLoadLowering
    : public OpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp read,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape.begin(),
                                                  vectorShape.end());
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *loadOp;
    if (read.getMask()) {
      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      loadOp = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      loadOp = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
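    // E.g., with broadcast dims {0, 1}, a loaded vector<1x1x4xf32> is
    // broadcast back to the original vector<2x3x4xf32> result (illustrative
    // shapes).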
    if (!broadcastedDims.empty()) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
          read, read.getVectorType(), loadOp->getResult(0));
    } else {
      rewriter.replaceOp(read, loadOp->getResult(0));
    }

    return success();
  }

  std::optional<unsigned> maxTransferRank;
};

/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");

    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");

    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
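/// Ex (illustrative; assumes the conditions above hold):
///     vector.transfer_write %v, %mem[%i, %j] {in_bounds = [true]}
///         : vector<4xf32>, memref<8x16xf32>
/// into:
///     vector.store %v, %mem[%i, %j] : memref<8x16xf32>, vector<4xf32>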
struct TransferWriteToVectorStoreLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
                                PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
          write, write.getSource(), write.getIndices(), write.getMask(),
          write.getVector());
    } else {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          write, write.getVector(), write.getSource(), write.getIndices());
    }
    return success();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}
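// Illustrative usage sketch (not part of this file): a pass would typically
// populate these patterns and apply them greedily, e.g. using
// applyPatternsAndFoldGreedily from
// mlir/Transforms/GreedyPatternRewriteDriver.h:
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   vector::populateVectorTransferLoweringPatterns(patterns,
//                                                  /*maxTransferRank=*/1);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();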