//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the
/// potentially dead transfer_write are dominated by the overwriting
/// transfer_write.
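///
/// Example (illustrative sketch only; %v0, %v1, %m and %c0 are made up for
/// exposition and the exact IR depends on the input):
///
///   vector.transfer_write %v0, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///   // ... no aliasing read of %m in between ...
///   vector.transfer_write %v1, %m[%c0, %c0] : vector<4xf32>, memref<4x4xf32>
///
/// Here the first transfer_write can be erased: the second write accesses the
/// same memref with the same indices and vector type, post-dominates it, and
/// no read of the overwritten data is reachable in between.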
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check candidates that can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation())))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
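///
/// Example (illustrative sketch only; %v, %m, %c0 and the padding value %pad
/// are made up for exposition):
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true]}
///       : vector<4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true]}
///       : memref<4x4xf32>, vector<4xf32>
///
/// When the write dominates the read, accesses the same memref with the same
/// indices and vector type, and no potentially aliasing write can reach the
/// read in between, %r can simply be replaced by %v.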
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  SmallVector<int64_t> targetShape = llvm::to_vector(
      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it has no unit dims to drop.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns a copy of `shape` without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> reducedShape;
  llvm::copy_if(shape, std::back_inserter(reducedShape),
                [](int64_t dimSize) { return dimSize != 1; });
  return reducedShape;
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
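///
/// Example (illustrative sketch only; %m, %c0 and %pad are made up for
/// exposition, and an in-bounds, zero-index read is assumed):
///
///   %r = vector.transfer_read %m[%c0, %c0, %c0], %pad
///       : memref<1x4x1xf32>, vector<1x4x1xf32>
///
/// is rewritten roughly as
///
///   %sv = memref.subview %m[0, 0, 0] [1, 4, 1] [1, 1, 1]
///       : memref<1x4x1xf32> to memref<4xf32>
///   %r0 = vector.transfer_read %sv[%c0], %pad : memref<4xf32>, vector<4xf32>
///   %r = vector.shape_cast %r0 : vector<4xf32> to vector<1x4x1xf32>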
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    int vectorReducedRank = getReducedRank(vectorType.getShape());
    if (reducedRank != vectorReducedRank)
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    auto reducedVectorType = VectorType::get(
        getReducedShape(vectorType.getShape()), vectorType.getElementType());

    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
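///
/// Example (illustrative sketch only; %v, %m and %c0 are made up for
/// exposition, and an in-bounds, zero-index write is assumed):
///
///   vector.transfer_write %v, %m[%c0, %c0, %c0]
///       : vector<1x4x1xf32>, memref<1x4x1xf32>
///
/// is handled analogously to the read case above: %v is shape_cast to
/// vector<4xf32> and written into a rank-reduced memref.subview of %m.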
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    int vectorReducedRank = getReducedRank(vectorType.getShape());
    if (reducedRank != vectorReducedRank)
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    VectorType reducedVectorType = VectorType::get(
        getReducedShape(vectorType.getShape()), vectorType.getElementType());

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);

    return success();
  }
};

} // namespace

/// Return true if the inner dimensions of the memref type match the given
/// shape and are laid out contiguously in row-major order; otherwise return
/// false.
static bool hasMatchingInnerContigousShape(MemRefType memrefType,
                                           ArrayRef<int64_t> targetShape) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
    return false;
  if (strides.back() != 1)
    return false;
  strides.pop_back();
  int64_t flatDim = 1;
  for (auto [targetDim, memrefDim, memrefStride] :
       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
    flatDim *= memrefDim;
    if (flatDim != memrefStride || targetDim != memrefDim)
      return false;
  }
  return true;
}

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
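///
/// Example (illustrative sketch only; %m, %c0 and %pad are made up for
/// exposition, and an in-bounds, zero-index read of a contiguous memref is
/// assumed):
///
///   %r = vector.transfer_read %m[%c0, %c0], %pad
///       : memref<8x4xf32>, vector<8x4xf32>
///
/// is rewritten roughly as
///
///   %flat = memref.collapse_shape %m [[0, 1]]
///       : memref<8x4xf32> into memref<32xf32>
///   %r0 = vector.transfer_read %flat[%c0], %pad
///       : memref<32xf32>, vector<32xf32>
///   %r = vector.shape_cast %r0 : vector<32xf32> to vector<8x4xf32>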
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below is only valid for memrefs; bail on tensor
    // sources.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below is only valid for memrefs; bail on tensor
    // sources.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContigousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
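///
/// Example (illustrative sketch only; %m, %c0, %c1 and %pad are made up for
/// exposition, and the read is assumed to be in-bounds with %e as its only
/// use):
///
///   %v = vector.transfer_read %m[%c0, %c1], %pad
///       : memref<8x8xf32>, vector<4xf32>
///   %e = vector.extract %v[2] : vector<4xf32>
///
/// becomes, roughly, a scalar load at the offset index:
///
///   %i = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%c1]
///   %e = memref.load %m[%c0, %i] : memref<8x8xf32>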
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
      int64_t offset = it.value();
      int64_t idx =
          newIndices.size() - extractOp.getPosition().size() + it.index();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

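// The entry points below are typically driven from a pass. A minimal sketch
// (illustrative only; `MyPass`, the greedy-driver wiring and the default
// pattern benefits are assumptions, not part of this file):
//
//   void MyPass::runOnOperation() {
//     // Forward stores to loads and erase dead stores on memrefs.
//     IRRewriter rewriter(&getContext());
//     vector::transferOpflowOpt(rewriter, getOperation());
//
//     // Then simplify the remaining transfer ops with the patterns below.
//     RewritePatternSet patterns(&getContext());
//     vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//     vector::populateFlattenVectorTransferPatterns(patterns);
//     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//   }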
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store to load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}