//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor of `op` that lives directly in `region`, or nullptr if
/// `region` is not an ancestor of `op`.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this is the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
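///
/// Illustrative sketch (not taken from the original tests; %v0, %v1, %m and
/// %c0 are assumed to be defined elsewhere): the first write below is dead
/// because it is fully overwritten before any read of %m can execute:
///
///   vector.transfer_write %v0, %m[%c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<8xf32>
///   vector.transfer_write %v1, %m[%c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<8xf32>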
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse
    // into the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider
/// the transfer_write dominated by all the other candidates, as this is the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are
/// post-dominated by the transfer_write.
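///
/// Illustrative sketch (not taken from the original tests; %v, %m, %c0 and
/// %pad are assumed to be defined elsewhere): the read below can simply reuse
/// the value stored by the dominating write, so %r is replaced by %v and the
/// transfer_read is erased:
///
///   vector.transfer_write %v, %m[%c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<8xf32>
///   %r = vector.transfer_read %m[%c0], %pad {in_bounds = [true]}
///       : memref<8xf32>, vector<4xf32>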
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse
    // into the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Failed to do write-to-read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Returns a copy of `shape` without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> reducedShape;
  llvm::copy_if(shape, std::back_inserter(reducedShape),
                [](int64_t dimSize) { return dimSize != 1; });
  return reducedShape;
}

/// Converts OpFoldResults to an int64_t shape, dropping unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it has no unit dims.
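/// For instance (hypothetical shapes, not taken from the original tests), a
/// value of type `memref<1x8x1x4xf32>` with the default layout would be
/// rewritten to roughly:
///
///   memref.subview %input[0, 0, 0, 0] [1, 8, 1, 4] [1, 1, 1, 1]
///       : memref<1x8x1x4xf32> to memref<8x4xf32>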
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the resulting
/// type.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  auto reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
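///
/// A rough sketch of the rewrite (illustrative; shapes and value names are
/// made up, not taken from the original tests):
///
///   %0 = vector.transfer_read %src[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<1x8x1xf32>, vector<1x8x1xf32>
///
/// becomes, modulo folding, a rank-reduced subview plus a shape_cast:
///
///   %sv = memref.subview %src[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>
///   %1 = vector.transfer_read %sv[%c0], %pad {in_bounds = [true]}
///       : memref<8xf32>, vector<8xf32>
///   %0 = vector.shape_cast %1 : vector<8xf32> to vector<1x8x1xf32>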
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    auto reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims.
/// The vector shapes are also reduced accordingly.
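///
/// A rough sketch of the rewrite (illustrative; shapes and value names are
/// made up, not taken from the original tests):
///
///   vector.transfer_write %v, %dst[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x8x1xf32>, memref<1x8x1xf32>
///
/// becomes, modulo folding:
///
///   %sv = memref.subview %dst[0, 0, 0] [1, 8, 1] [1, 1, 1]
///       : memref<1x8x1xf32> to memref<8xf32>
///   %vc = vector.shape_cast %v : vector<1x8x1xf32> to vector<8xf32>
///   vector.transfer_write %vc, %sv[%c0] : vector<8xf32>, memref<8xf32>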
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    int vectorReducedRank = getReducedRank(vectorType.getShape());
    if (reducedRank != vectorReducedRank)
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    VectorType reducedVectorType = VectorType::get(
        getReducedShape(vectorType.getShape()), vectorType.getElementType());

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);

    return success();
  }
};

} // namespace

/// Return true if the memref's inner dimensions match the given target shape
/// and are contiguous (row-major) in memory. Otherwise return false.
static bool hasMatchingInnerContiguousShape(MemRefType memrefType,
                                            ArrayRef<int64_t> targetShape) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
    return false;
  if (strides.back() != 1)
    return false;
  strides.pop_back();
  int64_t flatDim = 1;
  for (auto [targetDim, memrefDim, memrefStride] :
       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
    flatDim *= memrefDim;
    if (flatDim != memrefStride || targetDim != memrefDim)
      return false;
  }
  return true;
}

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices` the
/// truncated indices where `firstDimToCollapse` is now the innermost dim.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
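///
/// A rough sketch of the rewrite (illustrative; shapes and value names are
/// made up, not taken from the original tests):
///
///   %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///
/// becomes, modulo folding:
///
///   %flat = memref.collapse_shape %src [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %1 = vector.transfer_read %flat[%c0], %pad {in_bounds = [true]}
///       : memref<32xf32>, vector<32xf32>
///   %0 = vector.shape_cast %1 : vector<32xf32> to vector<4x8xf32>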
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below is only valid for memrefs; bail out on
    // tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContiguousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below is only valid for memrefs; bail out on
    // tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContiguousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if the xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
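///
/// A rough sketch of the rewrite (illustrative; the older `vector.extract`
/// syntax is used, and %m, %i and %pad are assumed to be defined elsewhere):
///
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   %e = vector.extract %v[2] : vector<4xf32>
///
/// becomes, with the constant extract position folded into the load index:
///
///   %idx = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%i]
///   %e = memref.load %m[%idx] : memref<?xf32>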
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
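///
/// A rough sketch of the rewrite (illustrative; shapes and value names are
/// made up, not taken from the original tests):
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<4x4xf32>
///
/// becomes:
///
///   %s = vector.extract %v[0, 0] : vector<1x1xf32>
///   memref.store %s, %m[%c0, %c0] : memref<4x4xf32>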
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}