//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// inverse of the given permutation: the value at position `i` moves to
/// position `permutation[i]`. E.g., with permutation = [2, 0, 1], in_bounds
/// [true, false, false] becomes [false, false, true].
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());

  SmallVector<bool> newScalableDims(addedRank, false);
  newScalableDims.append(originalVecType.getScalableDims().begin(),
                         originalVecType.getScalableDims().end());
  VectorType newVecType = VectorType::get(
      newShape, originalVecType.getElementType(), newScalableDims);
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = cast<VectorType>(broadcasted.getType()).getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}
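// For example (illustrative, not part of the original comments): with
// addedRank = 2, extendVectorRank turns %v : vector<8x16xf32> into a
// vector.broadcast to vector<1x1x8x16xf32> (the unit dims become the
// outermost dimensions), while extendMaskRank additionally transposes the
// broadcast result with permutation [2, 3, 0, 1], yielding
// vector<8x16x1x1xf32>, i.e. the unit dims become the innermost dimensions.
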
//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_read inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op, "map is already the identity");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
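    // E.g. (illustrative), for map = (d0, d1, d2) -> (0, d1) the computed
    // permutation is [1, 0], so permutationMap = (d0, d1) -> (d1, d0), which
    // is its own inverse, and newMap = (d0, d1, d2) -> (d1, 0), matching the
    // first example in the pattern documentation above.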
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    return rewriter
        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
        .getResult();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
///     %tmp = vector.transpose %v, [1, 2, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
///     %tmp = vector.transpose %v, [1, 0]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map, e.g.:
    //   map  = (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return dyn_cast<AffineDimExpr>(expr).getPosition();
                    });
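    // Continuing the example above (illustrative): inversePermutation(comp)
    // is (d0, d1, d2) -> (d1, d2, d0), so indices = [1, 2, 0] and the vector
    // is transposed with that permutation before the minor-identity write.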

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Convert a transfer_write op with a map that is not a permutation of a
/// minor identity into a vector.broadcast + transfer_write whose map is such
/// a permutation, by adding unit dims for the missing inner dimensions. Ex:
/// ```
///    vector.transfer_write %v
///      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///      vector<8x16xf32>
/// ```
/// into:
/// ```
///    %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
///    vector.transfer_write %v1
///      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///      vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed. Find the outermost dimension that
    // appears in the map, then deduce the missing inner dimensions that
    // follow it.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once we have found an outer dimension that exists in the map, keep
      // track of all the missing dimensions after it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
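    // E.g. (illustrative), for map = (d0, d1, d2, d3) -> (d1, d2): d1 is the
    // outermost dimension present and d3 is the missing inner dimension, so
    // missingInnerDim = [3] and exprs = [d3]; after appending the original
    // results below, the new map becomes (d0, d1, d2, d3) -> (d3, d1, d2).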
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the newly added dimensions are in-bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: support masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!constExpr || constExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate the new map, vector type and mask without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    return rewriter
        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
        .getVector();
  }
};

} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}
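
// Typical usage (illustrative sketch, not part of the original file): these
// patterns are meant to be applied greedily before transfer ops are lowered
// further, e.g.:
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
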
//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - The stride of the most minor memref dimension is 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is
///   allowed).
struct TransferReadToVectorLoadLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp read,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector rank is greater than max transfer rank");
    }

    if (maskOp)
      return rewriter.notifyMatchFailure(read, "Masked case not supported");
    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!memRefType.isLastDimUnitStride())
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());
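    // E.g. (illustrative): reading vector<4x8xf32> with broadcastedDims = {0}
    // loads vector<1x8xf32> and broadcasts it back to vector<4x8xf32> below.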

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *res;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      res = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      res = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty())
      res = rewriter.create<vector::BroadcastOp>(
          read.getLoc(), read.getVectorType(), res->getResult(0));
    return res->getResult(0);
  }

  std::optional<unsigned> maxTransferRank;
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - The stride of the most minor memref dimension is 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector rank is greater than max transfer rank");
    }
    if (maskOp)
      return rewriter.notifyMatchFailure(write, "Masked case not supported");

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!memRefType.isLastDimUnitStride())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });
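    // E.g. (illustrative): writing vector<4xf32> into memref<16xvector<4xf32>>
    // is accepted, while writing it into memref<16xvector<8xf32>> is rejected
    // above.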

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      rewriter.create<vector::MaskedStoreOp>(
          write.getLoc(), write.getSource(), write.getIndices(),
          write.getMask(), write.getVector());
    } else {
      rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
                                       write.getSource(), write.getIndices());
    }
    // There's no return value for StoreOps. Use Value() to signal success to
    // matchAndRewrite.
    return Value();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
}
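
// Typical usage (illustrative sketch, not part of the original file): callers
// often restrict transfers to rank 1 via maxTransferRank so that the
// resulting vector.load/vector.store ops lower directly to LLVM, e.g.:
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferLoweringPatterns(patterns,
//                                                  /*maxTransferRank=*/1);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));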