//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  return start->getBlock()->isReachable(dest->getBlock());
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this is the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
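/// For example (illustrative IR), the first write below is dead: the second
/// write post-dominates it, accesses the same memref with the same indices
/// and vector type, and no read of %buf is reachable in between:
///
///   vector.transfer_write %v0, %buf[%c0, %c0]
///       : vector<4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %buf[%c0, %c0]
///       : vector<4xf32>, memref<4x4xf32>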
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check for a candidate write that can overwrite the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// To be a candidate for store-to-load forwarding, a transfer_write must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this is the last
/// transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
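/// For example (illustrative IR), the read below can be replaced by %v
/// because the dominating write stores the same vector type at the same
/// indices and no other write to %buf can occur in between:
///
///   vector.transfer_write %v, %buf[%c0, %c0]
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %buf[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>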
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write but we can prove that it is disjoint, we can
      // ignore it.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to an int64_t shape without unit dims.
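/// (Illustrative: mixed sizes [%dyn, 1, 4, 1] become [ShapedType::kDynamic, 4].)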
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return cast<MemRefType>(rankReducedType).canonicalizeStridedLayout();
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (resultType.canonicalizeStridedLayout() ==
      inputType.canonicalizeStridedLayout())
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the result
/// type.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask operand for the unit dim is not a constant of 1, bail out.
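      // (Illustrative: for a vector<1x[4]xi1> mask built as
      //  vector.create_mask %c1, %n, the unit-dim operand %c1 is dropped and
      //  only %n is kept; a non-constant or non-1 operand aborts the rewrite.)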
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
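    // (Illustrative: a read of vector<1x4x1x8xf32> from memref<1x4x1x8xf32>
    //  becomes a read of vector<4x8xf32> from a rank-reduced subview.)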
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims.
/// The vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
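    // (Illustrative: a write of vector<1x4x1x8xf32> into memref<1x4x1x8xf32>
    //  becomes a write of vector<4x8xf32> into a rank-reduced subview.)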
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With memref semantics, there's no return value. Use an empty value to
    // signal success.
    return Value();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Returns the new indices that collapse the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
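  // (Illustrative: with indices (%i, %j, %k) and firstDimToCollapse = 1, %i is
  //  kept as-is and (%j, %k) are linearized into a single trailing offset.)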
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //    offset = 0
  //    for (i = firstDimToCollapse; i < outputRank; ++i)
  //      offset = offset * sourceType.getDimSize(i) + transferReadOp.indices[i]
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //     memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //     memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //   %offset = %arg0 * 2
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (auto value = dyn_cast<Value>(collapsedOffset)) {
    indicesAfterCollapsing.push_back(value);
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening only happens if the
/// bitwidth of the trailing vector dimension is smaller than the provided
/// bitwidth.
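/// For example (illustrative IR):
///
///   %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///
/// is rewritten to:
///
///   %collapsed = memref.collapse_shape %src [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %1 = vector.transfer_read %collapsed[%c0], %pad {in_bounds = [true]}
///       : memref<32xf32>, vector<32xf32>
///   %0 = vector.shape_cast %1 : vector<32xf32> to vector<4x8xf32>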
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The source must be a memref; contiguity cannot be checked on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2. New indices.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create a new vector.transfer_read that reads from the collapsed
    // memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source (i.e. the written-to memref) so that
/// the resulting vector.transfer_write has a 1D source. Requires the source
/// shape to be already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening only happens if the
/// bitwidth of the trailing vector dimension is smaller than the provided
/// bitwidth.
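/// For example (illustrative IR):
///
///   vector.transfer_write %v, %dst[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>
///
/// is rewritten to:
///
///   %collapsed = memref.collapse_shape %dst [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %flat = vector.shape_cast %v : vector<4x8xf32> to vector<32xf32>
///   vector.transfer_write %flat, %collapsed[%c0] {in_bounds = [true]}
///       : vector<32xf32>, memref<32xf32>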
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions
    // The source must be a memref; contiguity cannot be checked on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    // to the collapsed memref.
    // 2.1. New dim exprs + affine map.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2. New indices.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create a new vector.transfer_write that writes to the collapsed
    // memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing the
    // collapsed shape.
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
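  // (Illustrative: with a target of 128 bits, a trailing vector<4xi32> slice
  //  is already 128 bits wide, so the transfer is left unflattened.)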
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract`/`vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if the xfer op has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
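    // (Illustrative: %e = vector.extractelement %v[%j] with
    //  %v = vector.transfer_read %m[%i] becomes %e = memref.load %m[%i + %j].)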
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[newIndices.size() - 1] = value;
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(isa<Attribute>(pos) && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (auto value = dyn_cast<Value>(ofr)) {
        newIndices[idx] = value;
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
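    // (Illustrative: vector.transfer_write %v, %m[%i, %j] with
    //  %v : vector<1x1xf32> becomes a memref.store of the extracted scalar.)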
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
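// Usage sketch (illustrative, not part of this file's API): clients typically
// register these patterns and run them with the greedy rewrite driver, e.g.
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(
//       patterns, /*targetVectorBitwidth=*/128);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
//
// while `transferOpflowOpt` is called directly with a rewriter and the root op
// whose nested transfer_read/transfer_write ops should be optimized.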