//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check for a candidate that can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts a list of OpFoldResult sizes to an int64_t shape with unit dims
/// dropped.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
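///
/// Illustrative example (hypothetical types; a sketch of the generated IR):
/// for an input of type memref<1x8x1x16xf32> with the default row-major
/// layout, the created op would be roughly
///   memref.subview %input[0, 0, 0, 0] [1, 8, 1, 16] [1, 1, 1, 1]
///       : memref<1x8x1x16xf32> to memref<8x16xf32>
/// i.e. all static unit dims are dropped from the result type.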
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the resulting
/// type.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));

    return success();
  }
};

} // namespace

/// Returns true if the memref type's trailing dimensions match the given
/// target shape and are contiguous (row-major); returns false otherwise.
static bool hasMatchingInnerContiguousShape(MemRefType memrefType,
                                            ArrayRef<int64_t> targetShape) {
  auto shape = memrefType.getShape();
  SmallVector<int64_t> strides;
  int64_t offset;
  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
    return false;
  if (strides.back() != 1)
    return false;
  strides.pop_back();
  int64_t flatDim = 1;
  for (auto [targetDim, memrefDim, memrefStride] :
       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
    flatDim *= memrefDim;
    if (flatDim != memrefStride || targetDim != memrefDim)
      return false;
  }
  return true;
}

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
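///
/// Illustrative example (hypothetical values): for indices (%i, %j, %c0, %c0)
/// and `firstDimToCollapse` = 2, the two trailing indices are checked to be
/// constant 0 and `outIndices` becomes (%i, %j, %c0).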
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices.assign(indices.begin(), indices.end());
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below is only supported for memref sources, so bail
    // out on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContiguousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
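///
/// Illustrative example (hypothetical types and values): a write of
/// vector<4x8xf32> into memref<2x4x8xf32> at indices [%i, %c0, %c0] would,
/// roughly, become a vector.shape_cast to vector<32xf32> followed by an
/// in-bounds 1-D vector.transfer_write into the collapsed memref<2x32xf32>
/// at indices [%i, %c0].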
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below is only supported for memref sources, so bail
    // out on tensors.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!hasMatchingInnerContiguousShape(
            sourceType,
            vectorType.getShape().take_back(vectorType.getRank() - 1)))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }
};

/// Base class for `vector.extract(vector.transfer_read)` and
/// `vector.extractelement(vector.transfer_read)` to `memref.load` patterns.
/// The `match` method is shared for both `vector.extract` and
/// `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead-store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
}
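
// Illustrative usage sketch (not part of this file's API): a pass would
// typically populate one of the pattern sets above and apply it greedily
// from its runOnOperation(), e.g.:
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();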