xref: /llvm-project/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- XeGPUFoldAliasOps.cpp - XeGPU alias ops folders ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
14 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 #include "llvm/Support/Debug.h"
18 
19 namespace mlir {
20 namespace xegpu {
21 #define GEN_PASS_DEF_XEGPUFOLDALIASOPS
22 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
23 } // namespace xegpu
24 } // namespace mlir
25 
26 #define DEBUG_TYPE "xegpu-fold-alias-ops"
27 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
28 
29 using namespace mlir;
30 
31 namespace {
32 /// Merges subview operation with xegpu.create_nd_tdesc operation.
33 class XegpuCreateNdDescOpSubViewOpFolder final
34     : public OpRewritePattern<xegpu::CreateNdDescOp> {
35 public:
36   using OpRewritePattern<xegpu::CreateNdDescOp>::OpRewritePattern;
37 
38   LogicalResult matchAndRewrite(xegpu::CreateNdDescOp descOp,
39                                 PatternRewriter &rewriter) const override;
40 };
41 } // namespace
42 
43 LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite(
44     xegpu::CreateNdDescOp descOp, PatternRewriter &rewriter) const {
45   auto subViewOp = descOp.getSource().getDefiningOp<memref::SubViewOp>();
46 
47   if (!subViewOp)
48     return rewriter.notifyMatchFailure(descOp, "not a subview producer");
49   if (!subViewOp.hasUnitStride())
50     return rewriter.notifyMatchFailure(descOp, "requires unit strides");
51 
52   SmallVector<Value> resolvedOffsets;
53   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
54       rewriter, descOp.getLoc(), subViewOp.getMixedOffsets(),
55       subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
56       descOp.getMixedOffsets(), resolvedOffsets);
57 
58   rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
59       descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(),
60       getAsOpFoldResult(resolvedOffsets));
61 
62   return success();
63 }
64 
65 void xegpu::populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns) {
66   patterns.add<XegpuCreateNdDescOpSubViewOpFolder>(patterns.getContext());
67 }
68 
69 namespace {
70 
71 struct XeGPUFoldAliasOpsPass final
72     : public xegpu::impl::XeGPUFoldAliasOpsBase<XeGPUFoldAliasOpsPass> {
73   void runOnOperation() override;
74 };
75 
76 } // namespace
77 
78 void XeGPUFoldAliasOpsPass::runOnOperation() {
79   RewritePatternSet patterns(&getContext());
80   xegpu::populateXeGPUFoldAliasOpsPatterns(patterns);
81   (void)applyPatternsGreedily(getOperation(), std::move(patterns));
82 }
83