//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "vector-transfer-opt"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;

/// Return the ancestor op in the region or nullptr if the region is not
/// an ancestor of the op.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
  for (; op != nullptr && op->getParentRegion() != region;
       op = op->getParentOp())
    ;
  return op;
}

namespace {

class TransferOptimization {
public:
  TransferOptimization(RewriterBase &rewriter, Operation *op)
      : rewriter(rewriter), dominators(op), postDominators(op) {}
  void deadStoreOp(vector::TransferWriteOp);
  void storeToLoadForwarding(vector::TransferReadOp);
  void removeDeadOp() {
    for (Operation *op : opToErase)
      rewriter.eraseOp(op);
    opToErase.clear();
  }

private:
  RewriterBase &rewriter;
  bool isReachable(Operation *start, Operation *dest);
  DominanceInfo dominators;
  PostDominanceInfo postDominators;
  std::vector<Operation *> opToErase;
};

} // namespace

/// Return true if there is a path from the start operation to the dest
/// operation, otherwise return false. The operations have to be in the same
/// region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
  assert(start->getParentRegion() == dest->getParentRegion() &&
         "This function only works for ops in the same region");
  // Simple case where the start op dominates the destination.
  if (dominators.dominates(start, dest))
    return true;
  Block *startBlock = start->getBlock();
  Block *destBlock = dest->getBlock();
  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
                                    startBlock->succ_end());
  SmallPtrSet<Block *, 32> visited;
  while (!worklist.empty()) {
    Block *bb = worklist.pop_back_val();
    if (!visited.insert(bb).second)
      continue;
    if (dominators.dominates(bb, destBlock))
      return true;
    worklist.append(bb->succ_begin(), bb->succ_end());
  }
  return false;
}
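// The two rewrites implemented below (dead store elimination and
// store-to-load forwarding) are illustrated by the following sketch. This is
// illustrative IR only: SSA names, shapes and attributes are made up and
// simplified.
//
//   vector.transfer_write %v0, %alloc[%c0, %c0]
//       : vector<4xf32>, memref<4x4xf32>
//   %0 = vector.transfer_read %alloc[%c0, %c0], %pad
//       : memref<4x4xf32>, vector<4xf32>
//   vector.transfer_write %v1, %alloc[%c0, %c0]
//       : vector<4xf32>, memref<4x4xf32>
//
// Store-to-load forwarding replaces all uses of %0 with %v0; once the read is
// gone, the first transfer_write is fully overwritten by the last one and can
// be erased as a dead store.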
/// For a transfer_write to be fully overwritten by another transfer_write, the
/// overwriting transfer_write must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the transfer_write post-dominated by all the other candidates, as
/// this will be the first transfer_write executed after the potentially dead
/// transfer_write.
/// If we find such an overwriting transfer_write, we know that the original
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
                    << "\n");
  llvm::SmallVector<Operation *, 8> blockingAccesses;
  Operation *firstOverwriteCandidate = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user))
      continue;
    if (user == write.getOperation())
      continue;
    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
      // Check whether this candidate can overwrite the store.
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(nextWrite.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          checkSameValueWAW(nextWrite, write) &&
          postDominators.postDominates(nextWrite, write)) {
        if (firstOverwriteCandidate == nullptr ||
            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
          firstOverwriteCandidate = nextWrite;
        else
          assert(
              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
        continue;
      }
    }
    if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
      // Don't need to consider disjoint accesses.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(transferOp.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
    }
    blockingAccesses.push_back(user);
  }
  if (firstOverwriteCandidate == nullptr)
    return;
  Region *topRegion = firstOverwriteCandidate->getParentRegion();
  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
  assert(writeAncestor &&
         "write op should be recursively part of the top region");

  for (Operation *access : blockingAccesses) {
    Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
    // TODO: if the access and write have the same ancestor we could recurse in
    // the region to know if the access is reachable with more precision.
    if (accessAncestor == nullptr ||
        !isReachable(writeAncestor, accessAncestor))
      continue;
    if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
                        << *accessAncestor << "\n");
      return;
    }
  }
  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
  opToErase.push_back(write.getOperation());
}
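// Counter-example (illustrative sketch): here the candidate overwrite does not
// post-dominate the first write (it only executes when %cond is true), so the
// first write must be kept.
//
//   vector.transfer_write %v0, %alloc[%c0, %c0]
//       : vector<4xf32>, memref<4x4xf32>
//   scf.if %cond {
//     vector.transfer_write %v1, %alloc[%c0, %c0]
//         : vector<4xf32>, memref<4x4xf32>
//   }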

/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
///    transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the others
/// since they all dominate the same transfer_read. We only consider the
/// transfer_write dominated by all the other candidates, as this will be the
/// last transfer_write executed before the transfer_read.
/// If we find such a candidate, we can do the forwarding if all the other
/// potentially aliasing ops that may reach the transfer_read are post-dominated
/// by the transfer_write.
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
  if (read.hasOutOfBoundsDim())
    return;
  LLVM_DEBUG(DBGS() << "Candidate for forwarding: " << *read.getOperation()
                    << "\n");
  SmallVector<Operation *, 8> blockingWrites;
  vector::TransferWriteOp lastwrite = nullptr;
  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
  llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                           source.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> processed;
  while (!users.empty()) {
    Operation *user = users.pop_back_val();
    // If the user has already been processed, skip it.
    if (!processed.insert(user).second)
      continue;
    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
      continue;
    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
      // If there is a write, but we can prove that it is disjoint, we can
      // ignore the write.
      if (vector::isDisjointTransferSet(
              cast<VectorTransferOpInterface>(write.getOperation()),
              cast<VectorTransferOpInterface>(read.getOperation()),
              /*testDynamicValueUsingBounds=*/true))
        continue;
      if (memref::isSameViewOrTrivialAlias(
              cast<MemrefValue>(read.getSource()),
              cast<MemrefValue>(write.getSource())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
        else
          assert(dominators.dominates(write, lastwrite));
        continue;
      }
    }
    blockingWrites.push_back(user);
  }

  if (lastwrite == nullptr)
    return;

  Region *topRegion = lastwrite->getParentRegion();
  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
  assert(readAncestor &&
         "read op should be recursively part of the top region");

  for (Operation *write : blockingWrites) {
    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
    // TODO: if the store and read have the same ancestor we could recurse in
    // the region to know if the read is reachable with more precision.
    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
      continue;
    if (!postDominators.postDominates(lastwrite, write)) {
      LLVM_DEBUG(DBGS() << "Failed to forward the write to the read due to op: "
                        << *write << "\n");
      return;
    }
  }

  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
                    << " to: " << *read.getOperation() << "\n");
  read.replaceAllUsesWith(lastwrite.getVector());
  opToErase.push_back(read.getOperation());
}
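// Forwarding also applies when the writes in between are provably disjoint
// from the read (illustrative sketch; the second write touches a different row
// of %alloc, see isDisjointTransferSet):
//
//   vector.transfer_write %v0, %alloc[%c0, %c0]
//       : vector<4xf32>, memref<4x4xf32>
//   vector.transfer_write %v1, %alloc[%c1, %c0]
//       : vector<4xf32>, memref<4x4xf32>
//   %0 = vector.transfer_read %alloc[%c0, %c0], %pad
//       : memref<4x4xf32>, vector<4xf32>
//
// Here %0 can be replaced by %v0 because the intervening write does not alias
// the elements read.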

/// Converts OpFoldResults to an int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
  SmallVector<int64_t> reducedShape;
  for (const auto size : mixedSizes) {
    if (llvm::dyn_cast_if_present<Value>(size)) {
      reducedShape.push_back(ShapedType::kDynamic);
      continue;
    }

    auto value = cast<IntegerAttr>(size.get<Attribute>()).getValue();
    if (value == 1)
      continue;
    reducedShape.push_back(value.getSExtValue());
  }
  return reducedShape;
}

/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes,
                               ArrayRef<OpFoldResult> strides) {
  auto targetShape = getReducedShape(sizes);
  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
      targetShape, inputType, offsets, sizes, strides);
  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
}

/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or just returns the input if it already has no unit dims.
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                 mlir::Location loc,
                                                 Value input) {
  MemRefType inputType = cast<MemRefType>(input.getType());
  SmallVector<OpFoldResult> offsets(inputType.getRank(),
                                    rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
  SmallVector<OpFoldResult> strides(inputType.getRank(),
                                    rewriter.getIndexAttr(1));
  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);

  if (canonicalizeStridedLayout(resultType) ==
      canonicalizeStridedLayout(inputType))
    return input;
  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                            sizes, strides);
}

/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}

/// Trims non-scalable unit dimensions from `oldType` and returns the result
/// type.
static VectorType trimNonScalableUnitDims(VectorType oldType) {
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
      continue;
    newShape.push_back(dimSize);
    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
  }
  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
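// For instance (illustrative), rankReducingSubviewDroppingUnitDims turns a
// memref<1x4x1x8xf32> into a memref<4x8xf32> via a rank-reducing subview:
//
//   %sv = memref.subview %src[0, 0, 0, 0] [1, 4, 1, 8] [1, 1, 1, 1]
//       : memref<1x4x1x8xf32> to memref<4x8xf32>
//
// while trimNonScalableUnitDims turns vector<1x[4]x1xf32> into vector<[4]xf32>,
// keeping the scalable dimension.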
// Rewrites vector.create_mask 'op' to drop non-scalable unit dimensions.
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
                                  vector::CreateMaskOp op) {
  auto type = op.getType();
  VectorType reducedType = trimNonScalableUnitDims(type);
  if (reducedType.getRank() == type.getRank())
    return failure();

  SmallVector<Value> reducedOperands;
  for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
           type.getShape(), type.getScalableDims(), op.getOperands())) {
    if (dim == 1 && !dimIsScalable) {
      // If the mask for the unit dim is not a constant of 1, do nothing.
      auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
      if (!constant || (constant.value() != 1))
        return failure();
      continue;
    }
    reducedOperands.push_back(operand);
  }
  return rewriter
      .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
      .getResult();
}

namespace {

/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor types.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the source shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced source shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferReadOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
                            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }

    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
        transferReadOp.getPadding(), maskOp,
        rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newTransferReadOp = mlir::vector::maskOperation(
          rewriter, newTransferReadOp, shapeCastMask);
    }

    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, vectorType, newTransferReadOp->getResults()[0]);

    return shapeCast;
  }
};
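// Sketch of the rewrite performed by TransferReadDropUnitDimsPattern
// (illustrative IR; in_bounds and other attributes are omitted):
//
//   %0 = vector.transfer_read %src[%c0, %c0, %c0], %pad
//       : memref<1x8x1xf32>, vector<1x8x1xf32>
//
// becomes
//
//   %sv = memref.subview %src[0, 0, 0] [1, 8, 1] [1, 1, 1]
//       : memref<1x8x1xf32> to memref<8xf32>
//   %1 = vector.transfer_read %sv[%c0], %pad : memref<8xf32>, vector<8xf32>
//   %0 = vector.shape_cast %1 : vector<8xf32> to vector<1x8x1xf32>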

/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
                            vector::MaskingOpInterface maskingOp,
                            PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
    // TODO: support tensor type.
    if (!sourceType)
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Check if the destination shape can be further reduced.
    int reducedRank = getReducedRank(sourceType.getShape());
    if (reducedRank == sourceType.getRank())
      return failure();
    // TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
    // out.
    if (reducedRank == 0 && maskingOp)
      return failure();
    // Check if the reduced vector shape matches the reduced destination shape.
    // Otherwise, this case is not supported yet.
    VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
    if (reducedRank != reducedVectorType.getRank())
      return failure();
    if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
          return getConstantIntValue(v) != static_cast<int64_t>(0);
        }))
      return failure();

    Value maskOp = transferWriteOp.getMask();
    if (maskOp) {
      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
      if (!createMaskOp)
        return rewriter.notifyMatchFailure(
            transferWriteOp,
            "unsupported mask op, only 'vector.create_mask' is "
            "currently supported");
      FailureOr<Value> rankReducedCreateMask =
          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
      if (failed(rankReducedCreateMask))
        return failure();
      maskOp = *rankReducedCreateMask;
    }
    Value reducedShapeSource =
        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    SmallVector<Value> zeros(reducedRank, c0);
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
        loc, reducedVectorType, vector);
    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
        maskOp, rewriter.getBoolArrayAttr(inBounds));

    if (maskingOp) {
      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
          maskingOp.getMask());
      newXferWrite =
          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
    }

    if (transferWriteOp.hasPureTensorSemantics())
      return newXferWrite->getResults()[0];

    // With memref semantics, there's no return value. Use an empty value to
    // signal success.
    return Value();
  }
};
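// Sketch of the analogous transfer_write rewrite (illustrative IR; attributes
// omitted):
//
//   vector.transfer_write %v, %dst[%c0, %c0, %c0]
//       : vector<1x8x1xf32>, memref<1x8x1xf32>
//
// becomes
//
//   %sv = memref.subview %dst[0, 0, 0] [1, 8, 1] [1, 1, 1]
//       : memref<1x8x1xf32> to memref<8xf32>
//   %vc = vector.shape_cast %v : vector<1x8x1xf32> to vector<8xf32>
//   vector.transfer_write %vc, %sv[%c0] : vector<8xf32>, memref<8xf32>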

} // namespace

/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
                               Value input, int64_t firstDimToCollapse) {
  ShapedType inputType = cast<ShapedType>(input.getType());
  if (inputType.getRank() == 1)
    return input;
  SmallVector<ReassociationIndices> reassociation;
  for (int64_t i = 0; i < firstDimToCollapse; ++i)
    reassociation.push_back(ReassociationIndices{i});
  ReassociationIndices collapsedIndices;
  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
    collapsedIndices.push_back(i);
  reassociation.push_back(collapsedIndices);
  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}
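// For example (illustrative), with firstDimToCollapse = 1 a memref<2x3x4xf32>
// input is collapsed as:
//
//   %c = memref.collapse_shape %m [[0], [1, 2]]
//       : memref<2x3x4xf32> into memref<2x12xf32>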
/// Returns the new indices that collapse the inner dimensions starting from
/// the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
                                              Location loc,
                                              ArrayRef<int64_t> shape,
                                              ValueRange indices,
                                              int64_t firstDimToCollapse) {
  assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));

  // If all the collapsed indices are zero then no extra logic is needed.
  // Otherwise, a new offset/index has to be computed.
  SmallVector<Value> indicesAfterCollapsing(
      indices.begin(), indices.begin() + firstDimToCollapse);
  SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
                                       indices.end());
  if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
    indicesAfterCollapsing.push_back(indicesToCollapse[0]);
    return indicesAfterCollapsing;
  }

  // Compute the remaining trailing index/offset required for reading from
  // the collapsed memref:
  //
  //    offset = 0
  //    for (i = firstDimToCollapse; i < outputRank; ++i)
  //      offset += collapsedStrides[i] * transferReadOp.indices[i]
  //
  // where collapsedStrides[i] is the product of the dim sizes trailing dim i
  // (i.e. the suffix product of the collapsed shape).
  //
  // For this example:
  //   %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
  //      memref<1x43x2xi32>, vector<1x2xi32>
  // which would be collapsed to:
  //   %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
  //      memref<1x86xi32>, vector<2xi32>
  // one would get the following offset:
  //    %offset = %arg0 * 2
  OpFoldResult collapsedOffset =
      rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();

  auto collapsedStrides = computeSuffixProduct(
      ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));

  // Compute the collapsed offset.
  auto &&[collapsedExpr, collapsedVals] =
      computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
  collapsedOffset = affine::makeComposedFoldedAffineApply(
      rewriter, loc, collapsedExpr, collapsedVals);

  if (collapsedOffset.is<Value>()) {
    indicesAfterCollapsing.push_back(collapsedOffset.get<Value>());
  } else {
    indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
        loc, *getConstantIntValue(collapsedOffset)));
  }

  return indicesAfterCollapsing;
}

namespace {

/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening only happens if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferReadOp.getLoc();
    Value vector = transferReadOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    auto source = transferReadOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    // The contiguity check below only supports memrefs; bail out on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferReadOp.hasOutOfBoundsDim())
      return failure();
    if (!transferReadOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferReadOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_read that will read
    //    from the collapsed memref.
    // 2.1. New dim exprs + affine map.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2. New indices.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferReadOp.getIndices(), firstDimToCollapse);

    // 3. Create a new vector.transfer_read that reads from the collapsed
    //    memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_read with the new one reading from the
    //    collapsed shape.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
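// Sketch of the flattening rewrite for reads (illustrative IR; permutation
// maps and in_bounds attributes are omitted):
//
//   %0 = vector.transfer_read %src[%c0, %c0, %c0], %pad
//       : memref<2x3x4xf32>, vector<3x4xf32>
//
// becomes
//
//   %c = memref.collapse_shape %src [[0], [1, 2]]
//       : memref<2x3x4xf32> into memref<2x12xf32>
//   %1 = vector.transfer_read %c[%c0, %c0], %pad
//       : memref<2x12xf32>, vector<12xf32>
//   %0 = vector.shape_cast %1 : vector<12xf32> to vector<3x4xf32>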
/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening only happens if
/// the trailing dimension of the vector write is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {
    auto loc = transferWriteOp.getLoc();
    Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
    Value source = transferWriteOp.getSource();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());

    // 0. Check pre-conditions.
    // The contiguity check below only supports memrefs; bail out on tensors.
    if (!sourceType)
      return failure();
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
    if (transferWriteOp.hasOutOfBoundsDim())
      return failure();
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (transferWriteOp.getMask())
      return failure();

    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();

    // 1. Collapse the source memref.
    Value collapsedSource =
        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
    MemRefType collapsedSourceType =
        cast<MemRefType>(collapsedSource.getType());
    int64_t collapsedRank = collapsedSourceType.getRank();
    assert(collapsedRank == firstDimToCollapse + 1);

    // 2. Generate input args for a new vector.transfer_write that will write
    //    to the collapsed memref.
    // 2.1. New dim exprs + affine map.
    SmallVector<AffineExpr, 1> dimExprs{
        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
    auto collapsedMap =
        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());

    // 2.2. New indices.
    SmallVector<Value> collapsedIndices =
        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
                            transferWriteOp.getIndices(), firstDimToCollapse);

    // 3. Create a new vector.transfer_write that writes to the collapsed
    //    memref.
    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                vectorType.getElementType());
    Value flatVector =
        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
    vector::TransferWriteOp flatWrite =
        rewriter.create<vector::TransferWriteOp>(
            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

    // 4. Replace the old transfer_write with the new one writing to the
    //    collapsed shape.
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};
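// Sketch of the analogous rewrite for writes (illustrative IR; attributes
// omitted):
//
//   vector.transfer_write %v, %dst[%c0, %c0, %c0]
//       : vector<3x4xf32>, memref<2x3x4xf32>
//
// becomes
//
//   %c = memref.collapse_shape %dst [[0], [1, 2]]
//       : memref<2x3x4xf32> into memref<2x12xf32>
//   %vc = vector.shape_cast %v : vector<3x4xf32> to vector<12xf32>
//   vector.transfer_write %vc, %c[%c0, %c0] : vector<12xf32>, memref<2x12xf32>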

/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared for both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
    : public OpRewritePattern<VectorExtractOp> {
  using Base = OpRewritePattern<VectorExtractOp>;

public:
  RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                         PatternBenefit benefit,
                                         bool allowMultipleUses)
      : Base::OpRewritePattern(context, benefit),
        allowMultipleUses(allowMultipleUses) {}

  LogicalResult match(VectorExtractOp extractOp) const override {
    auto xferOp =
        extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
    if (!xferOp)
      return failure();
    // Check that we are extracting a scalar and not a sub-vector.
    if (isa<VectorType>(extractOp.getResult().getType()))
      return failure();
    // If multiple uses are not allowed, check if the xfer has a single use.
    if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
      return failure();
    // If multiple uses are allowed, check if all the xfer uses are extract ops.
    if (allowMultipleUses &&
        !llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
          return isa<vector::ExtractOp, vector::ExtractElementOp>(
              use.getOwner());
        }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Cannot rewrite if the indices may be out of bounds.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    return success();
  }

private:
  bool allowMultipleUses;
};

/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractElementOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractElementOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load.
    auto loc = extractOp.getLoc();
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    if (extractOp.getPosition()) {
      AffineExpr sym0, sym1;
      bindSymbols(extractOp.getContext(), sym0, sym1);
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, sym0 + sym1,
          {newIndices[newIndices.size() - 1], extractOp.getPosition()});
      if (ofr.is<Value>()) {
        newIndices[newIndices.size() - 1] = ofr.get<Value>();
      } else {
        newIndices[newIndices.size() - 1] =
            rewriter.create<arith::ConstantIndexOp>(loc,
                                                    *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};
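// Sketch of the scalar-load rewrite (illustrative IR; the extract position is
// folded into the load index with an affine.apply):
//
//   %0 = vector.transfer_read %src[%i], %pad : memref<?xf32>, vector<4xf32>
//   %1 = vector.extractelement %0[%j : index] : vector<4xf32>
//
// becomes
//
//   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
//   %1 = memref.load %src[%idx] : memref<?xf32>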

/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
class RewriteScalarExtractOfTransferRead
    : public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
  using RewriteScalarExtractOfTransferReadBase::
      RewriteScalarExtractOfTransferReadBase;

  void rewrite(vector::ExtractOp extractOp,
               PatternRewriter &rewriter) const override {
    // Construct a scalar load.
    auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
    SmallVector<Value> newIndices(xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
    for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
      assert(pos.is<Attribute>() && "Unexpected non-constant index");
      int64_t offset = cast<IntegerAttr>(pos.get<Attribute>()).getInt();
      int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(),
          rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
      if (ofr.is<Value>()) {
        newIndices[idx] = ofr.get<Value>();
      } else {
        newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
            extractOp.getLoc(), *getConstantIntValue(ofr));
      }
    }
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
                                                  newIndices);
    } else {
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
    }
  }
};

/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    // Must be a scalar write.
    auto vecType = xferOp.getVectorType();
    if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
      return failure();
    // Mask not supported.
    if (xferOp.getMask())
      return failure();
    // Map not supported.
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    // Only float and integer element types are supported.
    Value scalar;
    if (vecType.getRank() == 0) {
      // vector.extract does not support vector<f32> etc., so use
      // vector.extractelement instead.
      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
                                                         xferOp.getVector());
    } else {
      SmallVector<int64_t> pos(vecType.getRank(), 0);
      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
                                                  xferOp.getVector(), pos);
    }
    // Construct a scalar store.
    if (isa<MemRefType>(xferOp.getSource().getType())) {
      rewriter.replaceOpWithNewOp<memref::StoreOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    } else {
      rewriter.replaceOpWithNewOp<tensor::InsertOp>(
          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
    }
    return success();
  }
};
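// Sketch of the scalar-store rewrite (illustrative IR):
//
//   vector.transfer_write %v, %dst[%i, %j] : vector<1x1xf32>, memref<?x?xf32>
//
// becomes
//
//   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
//   memref.store %s, %dst[%i, %j] : memref<?x?xf32>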

} // namespace

void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
                                     Operation *rootOp) {
  TransferOptimization opt(rewriter, rootOp);
  // Run store-to-load forwarding first since it can expose more dead store
  // opportunities.
  rootOp->walk([&](vector::TransferReadOp read) {
    if (isa<MemRefType>(read.getShapedType()))
      opt.storeToLoadForwarding(read);
  });
  opt.removeDeadOp();
  rootOp->walk([&](vector::TransferWriteOp write) {
    if (isa<MemRefType>(write.getShapedType()))
      opt.deadStoreOp(write);
  });
  opt.removeDeadOp();
}

void mlir::vector::populateScalarVectorTransferLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit,
    bool allowMultipleUses) {
  patterns.add<RewriteScalarExtractElementOfTransferRead,
               RewriteScalarExtractOfTransferRead>(patterns.getContext(),
                                                   benefit, allowMultipleUses);
  patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}

void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
          patterns.getContext(), benefit);
  populateShapeCastFoldingPatterns(patterns);
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
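// Typical usage of the populate* entry points above from a pass (sketch only;
// the pass boilerplate and the greedy rewrite driver call are assumptions and
// not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(
//       patterns, /*targetVectorBitwidth=*/128);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));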