//===- XeGPUFoldAliasOps.cpp - XeGPU alias ops folders ----------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/XeGPU/Transforms/Passes.h" #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" namespace mlir { namespace xegpu { #define GEN_PASS_DEF_XEGPUFOLDALIASOPS #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" } // namespace xegpu } // namespace mlir #define DEBUG_TYPE "xegpu-fold-alias-ops" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") using namespace mlir; namespace { /// Merges subview operation with xegpu.create_nd_tdesc operation. class XegpuCreateNdDescOpSubViewOpFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(xegpu::CreateNdDescOp descOp, PatternRewriter &rewriter) const override; }; } // namespace LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite( xegpu::CreateNdDescOp descOp, PatternRewriter &rewriter) const { auto subViewOp = descOp.getSource().getDefiningOp(); if (!subViewOp) return rewriter.notifyMatchFailure(descOp, "not a subview producer"); if (!subViewOp.hasUnitStride()) return rewriter.notifyMatchFailure(descOp, "requires unit strides"); SmallVector resolvedOffsets; affine::resolveIndicesIntoOpWithOffsetsAndStrides( rewriter, descOp.getLoc(), subViewOp.getMixedOffsets(), subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), descOp.getMixedOffsets(), resolvedOffsets); rewriter.replaceOpWithNewOp( descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(), getAsOpFoldResult(resolvedOffsets)); return success(); } void xegpu::populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } namespace { struct XeGPUFoldAliasOpsPass final : public xegpu::impl::XeGPUFoldAliasOpsBase { void runOnOperation() override; }; } // namespace void XeGPUFoldAliasOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); xegpu::populateXeGPUFoldAliasOpsPatterns(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); }