//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Interfaces/VectorInterfaces.h"

using namespace mlir;
using namespace mlir::vector;

/// Transpose a vector transfer op's `in_bounds` attribute by applying the
/// reverse permutation based on the given indices.
static ArrayAttr
inverseTransposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
                             const SmallVector<unsigned> &permutation) {
  SmallVector<bool> newInBoundsValues(permutation.size());
  size_t index = 0;
  for (unsigned pos : permutation)
    newInBoundsValues[pos] =
        cast<BoolAttr>(attr.getValue()[index++]).getValue();
  return builder.getBoolArrayAttr(newInBoundsValues);
}

/// Extend the rank of a vector Value by `addedRank` by adding outer unit
/// dimensions.
static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
                              int64_t addedRank) {
  auto originalVecType = cast<VectorType>(vec.getType());
  SmallVector<int64_t> newShape(addedRank, 1);
  newShape.append(originalVecType.getShape().begin(),
                  originalVecType.getShape().end());

  SmallVector<bool> newScalableDims(addedRank, false);
  newScalableDims.append(originalVecType.getScalableDims().begin(),
                         originalVecType.getScalableDims().end());
  VectorType newVecType = VectorType::get(
      newShape, originalVecType.getElementType(), newScalableDims);
  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}

/// Extend the rank of a vector Value by `addedRank` by adding inner unit
/// dimensions.
static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
                            int64_t addedRank) {
  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
  SmallVector<int64_t> permutation;
  for (int64_t i = addedRank,
               e = cast<VectorType>(broadcasted.getType()).getRank();
       i < e; ++i)
    permutation.push_back(i);
  for (int64_t i = 0; i < addedRank; ++i)
    permutation.push_back(i);
  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferPermutationMapLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (0, d1)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2) -> (d1, 0)
///     vector.transpose %v, [1, 0]
///
///     vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
///     vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_read inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "0 result permutation map");
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(op, "map is already identity");

    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
    // Apply the reverse transpose to deduce the type of the transfer_read.
    ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
    SmallVector<int64_t> newVectorShape(originalShape.size());
    ArrayRef<bool> originalScalableDims = op.getVectorType().getScalableDims();
    SmallVector<bool> newScalableDims(originalShape.size());
    for (const auto &pos : llvm::enumerate(permutation)) {
      newVectorShape[pos.value()] = originalShape[pos.index()];
      newScalableDims[pos.value()] = originalScalableDims[pos.index()];
    }

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_read operation.
    VectorType newReadType = VectorType::get(
        newVectorShape, op.getVectorType().getElementType(), newScalableDims);
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);

    // Transpose result of transfer_read.
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    return rewriter
        .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm)
        .getResult();
  }
};

/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
///     into:
///     %tmp = vector.transpose %v, [2, 0, 1]
///     vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
///     vector.transfer_write %v ...
///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
///     into:
///     %tmp = vector.transpose %v, [1, 0]
///     %v = vector.transfer_write %tmp ...
///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isMinorIdentity())
      return rewriter.notifyMatchFailure(op, "map is already minor identity");

    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op, "map is not permutable to minor identity, apply another pattern");
    }

    // Remove unused dims from the permutation map. E.g.:
    //   map  = (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
    auto comp = compressUnusedDims(map);
    AffineMap permutationMap = inversePermutation(comp);
    // Get positions of remaining result dims.
    SmallVector<int64_t> indices;
    llvm::transform(permutationMap.getResults(), std::back_inserter(indices),
                    [](AffineExpr expr) {
                      return dyn_cast<AffineDimExpr>(expr).getPosition();
                    });

    // Transpose in_bounds attribute.
    ArrayAttr newInBoundsAttr =
        inverseTransposeInBoundsAttr(rewriter, op.getInBounds(), permutation);

    // Generate new transfer_write operation.
    Value newVec = rewriter.create<vector::TransposeOp>(
        op.getLoc(), op.getVector(), indices);
    auto newMap = AffineMap::getMinorIdentityMap(
        map.getNumDims(), map.getNumResults(), rewriter.getContext());
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Convert a transfer_write op with a map which is not a permutation of a
/// minor identity into a vector.broadcast + transfer_write with a permutation
/// of a minor identity map, by adding unit dims on the inner dimensions.
/// Ex:
/// ```
///   vector.transfer_write %v
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
///     vector<8x16xf32>
/// ```
/// into:
/// ```
///   %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
///   vector.transfer_write %v1
///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
///     vector<1x8x16xf32>
/// ```
struct TransferWriteNonPermutationLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: Support transfer_write inside MaskOp case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    SmallVector<unsigned> permutation;
    AffineMap map = op.getPermutationMap();
    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) {
      return rewriter.notifyMatchFailure(
          op,
          "map is already permutable to minor identity, apply another pattern");
    }

    // Missing outer dimensions are allowed: find the outermost existing
    // dimension, then deduce the missing inner dimensions.
    SmallVector<bool> foundDim(map.getNumDims(), false);
    for (AffineExpr exp : map.getResults())
      foundDim[cast<AffineDimExpr>(exp).getPosition()] = true;
    SmallVector<AffineExpr> exprs;
    bool foundFirstDim = false;
    SmallVector<int64_t> missingInnerDim;
    for (size_t i = 0; i < foundDim.size(); i++) {
      if (foundDim[i]) {
        foundFirstDim = true;
        continue;
      }
      if (!foundFirstDim)
        continue;
      // Once we have found one outer dimension that exists in the map, keep
      // track of all the missing dimensions after it.
      missingInnerDim.push_back(i);
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    // Vector: add unit dims at the beginning of the shape.
    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                    missingInnerDim.size());
    // Mask: add unit dims at the end of the shape.
    Value newMask;
    if (op.getMask())
      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
                               missingInnerDim.size());
    exprs.append(map.getResults().begin(), map.getResults().end());
    AffineMap newMap =
        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
    // All the newly added dimensions are in bounds.
    SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
    for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
      newInBoundsValues.push_back(op.isDimInBounds(i));
    }
    ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
    auto newWrite = rewriter.create<vector::TransferWriteOp>(
        op.getLoc(), newVec, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
    if (newWrite.hasPureTensorSemantics())
      return newWrite.getResult();
    // In the memref case there's no return value. Use an empty Value to
    // signal success.
    return Value();
  }
};

/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
///     into:
///     %v = vector.transfer_read ...
///         permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
///     vector.broadcast %v
struct TransferOpReduceRank
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp op,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (op.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-d corner case not supported");
    // TODO: support masked case.
    if (maskOp)
      return rewriter.notifyMatchFailure(op, "Masked case not supported");

    AffineMap map = op.getPermutationMap();
    unsigned numLeadingBroadcast = 0;
    for (auto expr : map.getResults()) {
      auto dimExpr = dyn_cast<AffineConstantExpr>(expr);
      if (!dimExpr || dimExpr.getValue() != 0)
        break;
      numLeadingBroadcast++;
    }
    // If there are no leading zeros in the map there is nothing to do.
    if (numLeadingBroadcast == 0)
      return rewriter.notifyMatchFailure(op, "no leading broadcasts in map");

    VectorType originalVecType = op.getVectorType();
    unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
    // Calculate new map, vector type and masks without the leading zeros.
    AffineMap newMap = AffineMap::get(
        map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
        op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
    if (!newMap.isMinorIdentityWithBroadcasting()) {
      return rewriter.notifyMatchFailure(
          op, "map is not a minor identity with broadcasting");
    }

    // TODO: support zero-dimension vectors natively. See:
    // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
    // In the meantime, lower these to a scalar load when they pop up.
    if (reducedShapeRank == 0) {
      Value newRead;
      if (isa<TensorType>(op.getShapedType())) {
        newRead = rewriter.create<tensor::ExtractOp>(
            op.getLoc(), op.getSource(), op.getIndices());
      } else {
        newRead = rewriter.create<memref::LoadOp>(
            op.getLoc(), originalVecType.getElementType(), op.getSource(),
            op.getIndices());
      }
      return rewriter
          .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
          .getVector();
    }

    SmallVector<int64_t> newShape(
        originalVecType.getShape().take_back(reducedShapeRank));
    SmallVector<bool> newScalableDims(
        originalVecType.getScalableDims().take_back(reducedShapeRank));
    // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
    if (newShape.empty())
      return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");

    VectorType newReadType = VectorType::get(
        newShape, originalVecType.getElementType(), newScalableDims);
    ArrayAttr newInBoundsAttr =
        op.getInBounds()
            ? rewriter.getArrayAttr(
                  op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
            : ArrayAttr();
    Value newRead = rewriter.create<vector::TransferReadOp>(
        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
        AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
        newInBoundsAttr);
    return rewriter
        .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
        .getVector();
  }
};

} // namespace

void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
          patterns.getContext(), benefit);
}

//===----------------------------------------------------------------------===//
// populateVectorTransferLoweringPatterns
//===----------------------------------------------------------------------===//

namespace {
/// Progressive lowering of transfer_read. This pattern supports lowering of
/// `vector.transfer_read` to a combination of `vector.load` and
/// `vector.broadcast` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   result type.
/// - The permutation map doesn't perform permutation (broadcasting is allowed).
struct TransferReadToVectorLoadLowering
    : public MaskableOpRewritePattern<vector::TransferReadOp> {
  TransferReadToVectorLoadLowering(MLIRContext *context,
                                   std::optional<unsigned> maxRank,
                                   PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferReadOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp read,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          read, "vector type is greater than max transfer rank");
    }

    if (maskOp)
      return rewriter.notifyMatchFailure(read, "Masked case not supported");
    SmallVector<unsigned> broadcastedDims;
    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    // We let the 0-d corner case pass-through as it is supported.
    if (!read.getPermutationMap().isMinorIdentityWithBroadcasting(
            &broadcastedDims))
      return rewriter.notifyMatchFailure(read, "not minor identity + bcast");

    auto memRefType = dyn_cast<MemRefType>(read.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(read, "not a memref source");

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");

    // If there is broadcasting involved then we first load the unbroadcasted
    // vector, and then broadcast it with `vector.broadcast`.
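    // For example (an illustrative sketch, not taken from a test): a read
    // with permutation_map (d0, d1) -> (0, d1) producing vector<4x8xf32>
    // first loads vector<1x8xf32> and then broadcasts it to vector<4x8xf32>.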
    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
    SmallVector<int64_t> unbroadcastedVectorShape(vectorShape);
    for (unsigned i : broadcastedDims)
      unbroadcastedVectorShape[i] = 1;
    VectorType unbroadcastedVectorType = read.getVectorType().cloneWith(
        unbroadcastedVectorShape, read.getVectorType().getElementType());

    // `vector.load` supports vector types as memref's elements only when the
    // resulting vector type is the same as the element type.
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != unbroadcastedVectorType)
      return rewriter.notifyMatchFailure(read, "incompatible element type");

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != read.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(read, "non-matching element type");

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (read.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(read, "out-of-bounds needs mask");

    // Create vector load op.
    Operation *res;
    if (read.getMask()) {
      if (read.getVectorType().getRank() != 1)
        // vector.maskedload operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            read, "vector type is not rank 1, can't create masked load, needs "
                  "VectorToSCF");

      Value fill = rewriter.create<vector::SplatOp>(
          read.getLoc(), unbroadcastedVectorType, read.getPadding());
      res = rewriter.create<vector::MaskedLoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices(), read.getMask(), fill);
    } else {
      res = rewriter.create<vector::LoadOp>(
          read.getLoc(), unbroadcastedVectorType, read.getSource(),
          read.getIndices());
    }

    // Insert a broadcasting op if required.
    if (!broadcastedDims.empty())
      res = rewriter.create<vector::BroadcastOp>(
          read.getLoc(), read.getVectorType(), res->getResult(0));
    return res->getResult(0);
  }

  std::optional<unsigned> maxTransferRank;
};

/// Replace a 0-d (or other single-element) vector.load with a memref.load +
/// vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
struct VectorLoadToMemrefLoadLowering
    : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = loadOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");

    auto memrefLoad = rewriter.create<memref::LoadOp>(
        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
                                                     memrefLoad);
    return success();
  }
};

/// Replace a 0-d (or other single-element) vector.store with a scalar element
/// extract + memref.store.
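///
/// For example (an illustrative sketch, not copied from a test):
///   vector.store %v, %mem[%i] : memref<?xf32>, vector<1xf32>
/// becomes an extract of the single element followed by
///   memref.store %e, %mem[%i] : memref<?xf32>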
struct VectorStoreToMemrefStoreLowering
    : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    auto vecType = storeOp.getVectorType();
    if (vecType.getNumElements() != 1)
      return rewriter.notifyMatchFailure(storeOp, "not single element vector");

    Value extracted;
    if (vecType.getRank() == 0) {
      // TODO: Unify once ExtractOp supports 0-d vectors.
      extracted = rewriter.create<vector::ExtractElementOp>(
          storeOp.getLoc(), storeOp.getValueToStore());
    } else {
      SmallVector<int64_t> indices(vecType.getRank(), 0);
      extracted = rewriter.create<vector::ExtractOp>(
          storeOp.getLoc(), storeOp.getValueToStore(), indices);
    }

    rewriter.replaceOpWithNewOp<memref::StoreOp>(
        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
    return success();
  }
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
/// - Out-of-bounds masking is not required.
/// - If the memref's element type is a vector type then it coincides with the
///   type of the written value.
/// - The permutation map is the minor identity map (neither permutation nor
///   broadcasting is allowed).
struct TransferWriteToVectorStoreLowering
    : public MaskableOpRewritePattern<vector::TransferWriteOp> {
  TransferWriteToVectorStoreLowering(MLIRContext *context,
                                     std::optional<unsigned> maxRank,
                                     PatternBenefit benefit = 1)
      : MaskableOpRewritePattern<vector::TransferWriteOp>(context, benefit),
        maxTransferRank(maxRank) {}

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp write,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) {
      return rewriter.notifyMatchFailure(
          write, "vector type is greater than max transfer rank");
    }
    if (maskOp)
      return rewriter.notifyMatchFailure(write, "Masked case not supported");

    // Permutations are handled by VectorToSCF or
    // populateVectorTransferPermutationMapLoweringPatterns.
    if ( // pass-through for the 0-d corner case.
        !write.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "permutation map is not minor identity: " << write;
      });

    auto memRefType = dyn_cast<MemRefType>(write.getShapedType());
    if (!memRefType)
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "not a memref type: " << write;
      });

    // Non-unit strides are handled by VectorToSCF.
    if (!isLastMemrefDimUnitStride(memRefType))
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "most minor stride is not 1: " << write;
      });

    // `vector.store` supports vector types as memref's elements only when the
    // type of the vector value being written is the same as the element type.
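    // E.g. (illustrative, not from a test): writing a vector<4xf32> into a
    // memref<8xvector<4xf32>> is accepted, while writing a vector<2x4xf32>
    // into that same memref is rejected below.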
    auto memrefElTy = memRefType.getElementType();
    if (isa<VectorType>(memrefElTy) && memrefElTy != write.getVectorType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Otherwise, element types of the memref and the vector must match.
    if (!isa<VectorType>(memrefElTy) &&
        memrefElTy != write.getVectorType().getElementType())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "elemental type mismatch: " << write;
      });

    // Out-of-bounds dims are handled by MaterializeTransferMask.
    if (write.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
        diag << "out of bounds dim: " << write;
      });
    if (write.getMask()) {
      if (write.getVectorType().getRank() != 1)
        // vector.maskedstore operates on 1-D vectors.
        return rewriter.notifyMatchFailure(
            write.getLoc(), [=](Diagnostic &diag) {
              diag << "vector type is not rank 1, can't create masked store, "
                      "needs VectorToSCF: "
                   << write;
            });

      rewriter.create<vector::MaskedStoreOp>(
          write.getLoc(), write.getSource(), write.getIndices(),
          write.getMask(), write.getVector());
    } else {
      rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
                                       write.getSource(), write.getIndices());
    }
    // There's no return value for StoreOps. Use Value() to signal success to
    // matchAndRewrite.
    return Value();
  }

  std::optional<unsigned> maxTransferRank;
};
} // namespace

void mlir::vector::populateVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, std::optional<unsigned> maxTransferRank,
    PatternBenefit benefit) {
  patterns.add<TransferReadToVectorLoadLowering,
               TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                   maxTransferRank, benefit);
  patterns
      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
          patterns.getContext(), benefit);
}
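
// Typical usage (an illustrative sketch, not part of this file; it assumes the
// standard greedy rewrite driver from
// mlir/Transforms/GreedyPatternRewriteDriver.h): a lowering pass collects
// these patterns and applies them to a root operation, e.g.
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   vector::populateVectorTransferLoweringPatterns(
//       patterns, /*maxTransferRank=*/std::nullopt);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));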