//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example
///
///   %0 = ... : memref<12x42xf32>
///   %1 = memref.expand_shape %0 [[0, 1], [2]]
///       : memref<12x42xf32> into memref<2x6x42xf32>
///   %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
///   %2 = load %0[6 * %i1 + %i2, %i3] : memref<12x42xf32>
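///
/// More generally (a sketch of the rule rather than a formal specification):
/// within each reassociation group, the folded source index is the linear
/// combination of that group's load/store indices with the suffix products of
/// the group's output sizes; for the group [0, 1] above that is
/// %i1 * 6 + %i2 * 1.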
static LogicalResult
resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                memref::ExpandShapeOp expandShapeOp,
                                ValueRange indices,
                                SmallVectorImpl<Value> &sourceIndices) {
  // Record the rewriter context for constructing ops later.
  MLIRContext *ctx = rewriter.getContext();

  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
  // This is done for the purpose of inferring the output shape via
  // `inferExpandShapeOutputShape`, which will in turn be used for suffix
  // product calculation later.
  SmallVector<OpFoldResult> srcShape;
  MemRefType srcType = expandShapeOp.getSrcType();

  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
    if (srcType.isDynamicDim(i)) {
      srcShape.push_back(
          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
              .getResult());
    } else {
      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
    }
  }

  auto outputShape = inferExpandShapeOutputShape(
      rewriter, loc, expandShapeOp.getResultType(),
      expandShapeOp.getReassociationIndices(), srcShape);
  if (!outputShape.has_value())
    return failure();

  // Traverse all reassociation groups to determine the appropriate indices
  // corresponding to each one of them post op folding.
  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    int64_t groupSize = groups.size();

    // Gather the output dimension sizes used in this reassociation group for
    // the suffix product calculation.
    SmallVector<OpFoldResult> sizesVal(groupSize);
    for (int64_t i = 0; i < groupSize; ++i) {
      sizesVal[i] = (*outputShape)[groups[i]];
    }

    // Calculate the suffix product of the relevant output dimension sizes.
    SmallVector<OpFoldResult> suffixProduct =
        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);

    // Create affine expression variables for dimensions and symbols in the
    // newly constructed affine map.
    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
    bindDimsList<AffineExpr>(ctx, dims);
    bindSymbolsList<AffineExpr>(ctx, symbols);

    // Linearize the bound dimensions and symbols to construct the resulting
    // affine expression for this index.
    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);

    // Record the load index corresponding to each dimension in the
    // reassociation group. These are later supplied as operands to the affine
    // map used for calculating the relevant index post op folding.
    SmallVector<OpFoldResult> dynamicIndices(groupSize);
    for (int64_t i = 0; i < groupSize; i++)
      dynamicIndices[i] = indices[groups[i]];

    // Supply the suffix product results followed by the load op indices as
    // operands to the map.
    SmallVector<OpFoldResult> mapOperands;
    llvm::append_range(mapOperands, suffixProduct);
    llvm::append_range(mapOperands, dynamicIndices);

    // Creating a maximally folded and composed affine.apply composes better
    // with other transformations without interleaving canonicalization
    // passes.
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc,
        AffineMap::get(/*numDims=*/groupSize,
                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
        mapOperands);

    // Push the index value corresponding to this reassociation group post
    // folding.
    sourceIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return success();
}

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example
///
///   %0 = ... : memref<2x6x42xf32>
///   %1 = memref.collapse_shape %0 [[0, 1], [2]]
///       : memref<2x6x42xf32> into memref<12x42xf32>
///   %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
///   %2 = load %0[%i1 / 6, %i1 % 6, %i2] : memref<2x6x42xf32>
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  int64_t cnt = 0;
  SmallVector<Value> tmp(indices.size());
  SmallVector<OpFoldResult> dynamicIndices;
  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    dynamicIndices.push_back(indices[cnt++]);
    int64_t groupSize = groups.size();

    // Calculate the suffix product for all collapse op source dimension sizes
    // except the most major one of each group.
    // We allow the most major source dimension to be dynamic but enforce all
    // others to be known statically.
    SmallVector<int64_t> sizes(groupSize, 1);
    for (int64_t i = 1; i < groupSize; ++i) {
      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
      if (sizes[i] == ShapedType::kDynamic)
        return failure();
    }
    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);

    // Derive the index values along all dimensions of the source
    // corresponding to the index w.r.t. the collapse_shape op's output.
    auto d0 = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);

    // Construct the AffineApplyOp for each delinearizingExpr.
    for (int64_t i = 0; i < groupSize; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc,
          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
                         delinearizingExprs[i]),
          dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
    dynamicIndices.clear();
  }
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    for (int64_t i = 0; i < srcRank; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, zeroAffineMap, dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}
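
// Note (illustrative sketch): the empty-reassociation branch above handles
// collapses to a rank-0 result, e.g. collapsing memref<1x1xf32> into
// memref<f32>. A `memref.load %1[]` through such a collapse_shape folds to a
// load of the original memref with constant-zero indices, roughly
// `memref.load %0[%c0, %c0]`.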

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getSource();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
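///
/// A hedged sketch of the rewrite (result types elided; they depend on how
/// the offsets fold into the layout). With unit strides the offsets simply
/// add up:
///
///   %1 = memref.subview %src[2, 4] [16, 16] [1, 1] : ...
///   %2 = memref.subview %1[3, 5] [4, 4] [1, 1] : ...
///
/// becomes a single subview of the original source:
///
///   %2 = memref.subview %src[5, 9] [4, 4] [1, 1] : ...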
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace the original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds subviews feeding nvgpu.device_async_copy into the copy itself. This
/// pattern folds subviews on both the src and the dst memref of the copy.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace
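
/// Applies each result expression of `affineMap` to `indices` and materializes
/// the results as index values. This is how affine.load / affine.store access
/// maps are expanded into the "actual" accessed indices before folding. As an
/// illustrative sketch: for a map (d0, d1) -> (d0 + 4, d1 * 2) and indices
/// (%i, %j), the returned values correspond to %i + 4 and %j * 2 (folded to
/// constants where possible).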
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the access map to the operands to get the "actual"
  // indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
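
// For reference, the subview fold above rewrites accesses along the lines of
// the following sketch (types and layout attributes elided; unit strides
// assumed):
//
//   %sv = memref.subview %src[%o0, %o1] [4, 4] [1, 1]
//   %v  = memref.load %sv[%i, %j]
//
// into a direct access of the original memref:
//
//   %v  = memref.load %src[%o0 + %i, %o1 + %j]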

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the access map to the operands to get the "actual"
  // indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, apply the access map to the operands to get the "actual"
  // indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the access map to the operands to get the "actual"
  // indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the access map to the operands to get the "actual"
  // indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, apply the access map to the operands to get the "actual"
  // indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
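
// Note: the async-copy fold below mirrors the load/store subview folds above:
// when the copy's src and/or dst memref comes from a memref.subview, the copy
// is rewritten to use the subview's source directly, with the subview offsets
// and strides folded into the corresponding copy indices.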

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, resolve the indices into the subview's source.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, resolve the indices into the subview's
  // source.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy op that uses the sources of the folded
  // subviews.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
  return std::make_unique<FoldMemRefAliasOpsPass>();
}