//===- VectorTransferOpTransforms.cpp - transfer op transforms -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace

/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they are all post-dominating the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check if the candidate can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation())))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
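    // If the blocking access is not reachable from the write, it can never
    // observe the stored value, so it does not keep the store alive.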
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they are all dominating the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
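      // `isDisjointTransferSet` conservatively proves, from the static indices
      // and vector shapes, that the two transfers touch non-overlapping parts
      // of the memref, so such a write cannot feed this read.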
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation())))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Failed to do write-to-read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
                               ArrayRef<int64_t> sizes,
                               ArrayRef<int64_t> strides) {
  SmallVector<int64_t> targetShape = llvm::to_vector(
      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  assert(inputType.hasStaticShape());
  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
  ArrayRef<int64_t> subViewSizes = inputType.getShape();
  MemRefType resultType =
      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(
      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Returns a copy of `shape` without unit dims.
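/// E.g. the reduced shape of (1, 4, 1, 8) is (4, 8).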
static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> reducedShape;
  llvm::copy_if(shape, std::back_inserter(reducedShape),
                [](int64_t dimSize) { return dimSize != 1; });
  return reducedShape;
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    int vectorReducedRank = getReducedRank(vectorType.getShape());
    if (reducedRank != vectorReducedRank)
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    auto reducedVectorType = VectorType::get(
        getReducedShape(vectorType.getShape()), vectorType.getElementType());

    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
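///
/// A minimal sketch of the rewrite (SSA names and types are illustrative only):
///   vector.transfer_write %v, %M[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<1x16x1xf32>, memref<1x16x1xf32>
/// becomes, roughly,
///   %sv = memref.subview %M[0, 0, 0] [1, 16, 1] [1, 1, 1]
///       : memref<1x16x1xf32> to memref<16xf32>
///   %vc = vector.shape_cast %v : vector<1x16x1xf32> to vector<16xf32>
///   vector.transfer_write %vc, %sv[%c0] : vector<16xf32>, memref<16xf32>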
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();
    if (sourceType.getNumElements() != vectorType.getNumElements())
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    int vectorReducedRank = getReducedRank(vectorType.getShape());
    if (reducedRank != vectorReducedRank)
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    VectorType reducedVectorType = VectorType::get(
        getReducedShape(vectorType.getShape()), vectorType.getElementType());

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);

    return success();
  }
};

} // namespace

/// Returns true if the memref type has its innermost dimensions contiguous and
/// matching the given shape. Otherwise returns false.
static bool hasMatchingInnerContiguousShape(MemRefType memrefType,
                                            ArrayRef<int64_t> targetShape) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
    return false;
  if (strides.back() != 1)
    return false;
  strides.pop_back();
  int64_t flatDim = 1;
  for (auto [targetDim, memrefDim, memrefStride] :
       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
    flatDim *= memrefDim;
    if (flatDim != memrefStride || targetDim != memrefDim)
      return false;
  }
  return true;
}

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
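/// E.g. with `firstDimToCollapse` = 1, a contiguous memref<2x3x4xf32> is
/// collapsed to memref<2x12xf32> using the reassociation [[0], [1, 2]].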
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices.assign(indices.begin(), indices.end());
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below only supports memref sources; bail out on
    // tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContiguousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
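    // Bail out on out-of-bounds dims, non-minor-identity permutation maps, and
    // masks: the flattened 1-D transfer created below is marked in-bounds and
    // assumes a plain row-major access.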
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below only supports memref sources; bail out on
    // tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContiguousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

/// Base class for `vector.extract`/`vector.extractelement` (of a
/// `vector.transfer_read`) to `memref.load` patterns. The `match` method is
/// shared for both `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
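    // Fold the static extraction position into the innermost transfer indices,
    // then replace the extract with a scalar load from the adjusted location.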
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (const auto &it : llvm::enumerate(extractOp.getPosition())) {
      int64_t offset = it.value();
      int64_t idx =
          newIndices.size() - extractOp.getPosition().size() + it.index();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
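  // Both rewrites only apply to transfers on memrefs; transfers on tensors are
  // skipped below.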
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}
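
// Usage sketch (illustrative only): the populate* entry points above are
// typically combined with a greedy pattern driver, e.g.:
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));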