//===- VectorTransferOpTransforms.cpp - transfer op transforms -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace

/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}

/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this is the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads reachable from the potentially dead
/// transfer_write are dominated by the overwriting transfer_write.
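///
/// Illustrative example (hypothetical IR, not taken from the tests): the first
/// write below is dead because the second write post-dominates it, accesses
/// the same memref with the same indices and vector type, and no read of %m
/// happens in between:
///
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true, true]}
///     : vector<4x8xf32>, memref<4x8xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true, true]}
///     : vector<4x8xf32>, memref<4x8xf32>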
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = write.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check for a candidate that can overwrite the store.
      if (write.getSource() == nextWrite.getSource() &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this is the last
/// transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
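///
/// Illustrative example (hypothetical IR, not taken from the tests): the read
/// below can simply reuse %v because the write dominates it, accesses the same
/// memref with the same indices and vector type, and no aliasing write happens
/// in between:
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///     : vector<4x8xf32>, memref<4x8xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///     : memref<4x8xf32>, vector<4x8xf32>
///
/// After forwarding, all uses of %r are replaced by %v and the read is erased.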
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = read.getSource();
  // Skip subview ops.
  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
    source = subView.getSource();
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
      users.append(subView->getUsers().begin(), subView->getUsers().end());
      continue;
    }
    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (write.getSource() == read.getSource() &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}

/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input. Or just returns the input if it was already without unit dims.
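///
/// Illustrative example (hypothetical IR): for a fully static, contiguous
/// input this creates something like
///
///   %0 = memref.subview %m[0, 0, 0, 0] [1, 8, 1, 4] [1, 1, 1, 1]
///     : memref<1x8x1x4xf32> to memref<8x4xf32>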
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the result
/// type.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}

// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
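///
/// Illustrative example (hypothetical IR; assumes zero indices, an in-bounds
/// access and a minor-identity permutation map):
///
///   %0 = vector.transfer_read %m[%c0, %c0, %c0, %c0], %pad
///          {in_bounds = [true, true, true, true]}
///     : memref<1x8x1x4xf32>, vector<1x8x1x4xf32>
///
/// is rewritten to roughly
///
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 8, 1, 4] [1, 1, 1, 1]
///     : memref<1x8x1x4xf32> to memref<8x4xf32>
///   %1 = vector.transfer_read %sv[%c0, %c0], %pad {in_bounds = [true, true]}
///     : memref<8x4xf32>, vector<8x4xf32>
///   %0 = vector.shape_cast %1 : vector<8x4xf32> to vector<1x8x1x4xf32>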
class TransferReadDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp);
    rewriter.replaceOp(transferReadOp, shapeCast);

    return success();
  }
};

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));

    return success();
  }
};

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}

/// Checks that the indices corresponding to dimensions starting at
/// `firstDimToCollapse` are constant 0, and writes to `outIndices`
/// the truncated indices where `firstDimToCollapse` is now the innermost dim.
/// TODO: Extract the logic that writes to outIndices so that this method
/// simply checks one pre-condition.
static LogicalResult
checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                 SmallVector<Value> &outIndices) {
  int64_t rank = indices.size();
  if (firstDimToCollapse >= rank)
    return failure();
  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
    std::optional<int64_t> cst = getConstantIntValue(indices[i]);
    if (!cst || cst.value() != 0)
      return failure();
  }
  outIndices = indices;
  outIndices.resize(firstDimToCollapse + 1);
  return success();
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the total bitwidth of the trailing vector dimension is smaller than the
/// provided bitwidth.
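///
/// Illustrative example (hypothetical IR; assumes a sufficiently large
/// `targetVectorBitwidth` and zero indices):
///
///   %0 = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///     : memref<4x8xf32>, vector<4x8xf32>
///
/// is rewritten to roughly
///
///   %cs = memref.collapse_shape %m [[0, 1]]
///     : memref<4x8xf32> into memref<32xf32>
///   %1 = vector.transfer_read %cs[%c0], %pad {in_bounds = [true]}
///     : memref<32xf32>, vector<32xf32>
///   %0 = vector.shape_cast %1 : vector<32xf32> to vector<4x8xf32>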
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    // The contiguity check below is only valid for memrefs, so bail out on
    // tensor sources.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        dyn_cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    // from the collapsed memref.
    // 2.1. New dim exprs + affine map.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2 New indices.
    // If all the collapsed indices are zero then no extra logic is needed.
    // Otherwise, a new offset/index has to be computed.
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
                                                firstDimToCollapse,
                                                collapsedIndices))) {
      // Copy all the leading indices.
      SmallVector<Value> indices = transferReadOp.getIndices();
      collapsedIndices.append(indices.begin(),
                              indices.begin() + firstDimToCollapse);

      // Compute the remaining trailing index/offset required for reading from
      // the collapsed memref:
      //
      //    offset = 0
      //    for (i = firstDimToCollapse; i < sourceRank; ++i)
      //      offset += collapsedStrides[i - firstDimToCollapse] *
      //                transferReadOp.indices[i]
      //
      // where `collapsedStrides` is the suffix product of the dim sizes being
      // collapsed. For this example:
      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
      //      memref<1x43x2xi32>, vector<1x2xi32>
      // which would be collapsed to:
      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
      //      memref<1x86xi32>, vector<2xi32>
      // one would get the following offset:
      //   %offset = %arg0 * 2
      OpFoldResult collapsedOffset =
          rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

      auto sourceShape = sourceType.getShape();
      auto collapsedStrides = computeSuffixProduct(ArrayRef<int64_t>(
          sourceShape.begin() + firstDimToCollapse, sourceShape.end()));

      // Compute the collapsed offset.
      ArrayRef<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                        indices.end());
      auto &&[collapsedExpr, collapsedVals] = computeLinearIndex(
          collapsedOffset, collapsedStrides, indicesToCollapse);
      collapsedOffset = affine::makeComposedFoldedAffineApply(
          rewriter, loc, collapsedExpr, collapsedVals);

      if (collapsedOffset.is<Value>()) {
        collapsedIndices.push_back(collapsedOffset.get<Value>());
      } else {
        collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
            loc, *getConstantIntValue(collapsedOffset)));
      }
    }

    // 3. Create a new vector.transfer_read that reads from the collapsed
    // memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    // collapsed shape.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // The contiguity check below is only valid for memrefs, so bail out on
    // tensor destinations.
    if (!sourceType)
      return failure();
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    int64_t firstContiguousInnerDim =
        sourceType.getRank() - vectorType.getRank();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();
    SmallVector<Value> collapsedIndices;
    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
                                                firstContiguousInnerDim,
                                                collapsedIndices)))
      return failure();

    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstContiguousInnerDim + 1);
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
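///
/// Illustrative example (hypothetical IR; the exact `vector.extract` syntax
/// depends on the dialect version):
///
///   %0 = vector.transfer_read %m[%i, %j], %pad {in_bounds = [true]}
///     : memref<?x?xf32>, vector<4xf32>
///   %1 = vector.extract %0[2] : f32 from vector<4xf32>
///
/// becomes, when %j is not a constant,
///
///   %idx = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%j]
///   %1 = memref.load %m[%i, %idx] : memref<?x?xf32>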
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};

} // namespace
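
/// Runs store-to-load forwarding and then dead store elimination on all
/// vector transfer ops on memrefs nested under `rootOp`.
///
/// Illustrative usage from inside a pass (a sketch; the surrounding pass and
/// its `runOnOperation` hook are hypothetical):
///
///   void runOnOperation() override {
///     IRRewriter rewriter(&getContext());
///     vector::transferOpflowOpt(rewriter, getOperation());
///   }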
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}