1f4c0c40fSAdam Siemieniuk //===- XeGPUFoldAliasOps.cpp - XeGPU alias ops folders ----------*- C++ -*-===// 2f4c0c40fSAdam Siemieniuk // 3f4c0c40fSAdam Siemieniuk // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4f4c0c40fSAdam Siemieniuk // See https://llvm.org/LICENSE.txt for license information. 5f4c0c40fSAdam Siemieniuk // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6f4c0c40fSAdam Siemieniuk // 7f4c0c40fSAdam Siemieniuk //===----------------------------------------------------------------------===// 8f4c0c40fSAdam Siemieniuk 9f4c0c40fSAdam Siemieniuk #include "mlir/Dialect/XeGPU/Transforms/Passes.h" 10f4c0c40fSAdam Siemieniuk 11f4c0c40fSAdam Siemieniuk #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" 12f4c0c40fSAdam Siemieniuk #include "mlir/Dialect/MemRef/IR/MemRef.h" 13f4c0c40fSAdam Siemieniuk #include "mlir/Dialect/XeGPU/IR/XeGPU.h" 14f4c0c40fSAdam Siemieniuk #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" 15f4c0c40fSAdam Siemieniuk #include "mlir/Pass/Pass.h" 16f4c0c40fSAdam Siemieniuk #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 17f4c0c40fSAdam Siemieniuk #include "llvm/Support/Debug.h" 18f4c0c40fSAdam Siemieniuk 19f4c0c40fSAdam Siemieniuk namespace mlir { 20f4c0c40fSAdam Siemieniuk namespace xegpu { 21f4c0c40fSAdam Siemieniuk #define GEN_PASS_DEF_XEGPUFOLDALIASOPS 22f4c0c40fSAdam Siemieniuk #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" 23f4c0c40fSAdam Siemieniuk } // namespace xegpu 24f4c0c40fSAdam Siemieniuk } // namespace mlir 25f4c0c40fSAdam Siemieniuk 26f4c0c40fSAdam Siemieniuk #define DEBUG_TYPE "xegpu-fold-alias-ops" 27f4c0c40fSAdam Siemieniuk #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 28f4c0c40fSAdam Siemieniuk 29f4c0c40fSAdam Siemieniuk using namespace mlir; 30f4c0c40fSAdam Siemieniuk 31f4c0c40fSAdam Siemieniuk namespace { 32f4c0c40fSAdam Siemieniuk /// Merges subview operation with xegpu.create_nd_tdesc operation. 33f4c0c40fSAdam Siemieniuk class XegpuCreateNdDescOpSubViewOpFolder final 34f4c0c40fSAdam Siemieniuk : public OpRewritePattern<xegpu::CreateNdDescOp> { 35f4c0c40fSAdam Siemieniuk public: 36f4c0c40fSAdam Siemieniuk using OpRewritePattern<xegpu::CreateNdDescOp>::OpRewritePattern; 37f4c0c40fSAdam Siemieniuk 38f4c0c40fSAdam Siemieniuk LogicalResult matchAndRewrite(xegpu::CreateNdDescOp descOp, 39f4c0c40fSAdam Siemieniuk PatternRewriter &rewriter) const override; 40f4c0c40fSAdam Siemieniuk }; 41f4c0c40fSAdam Siemieniuk } // namespace 42f4c0c40fSAdam Siemieniuk 43f4c0c40fSAdam Siemieniuk LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite( 44f4c0c40fSAdam Siemieniuk xegpu::CreateNdDescOp descOp, PatternRewriter &rewriter) const { 45f4c0c40fSAdam Siemieniuk auto subViewOp = descOp.getSource().getDefiningOp<memref::SubViewOp>(); 46f4c0c40fSAdam Siemieniuk 47f4c0c40fSAdam Siemieniuk if (!subViewOp) 48f4c0c40fSAdam Siemieniuk return rewriter.notifyMatchFailure(descOp, "not a subview producer"); 49f4c0c40fSAdam Siemieniuk if (!subViewOp.hasUnitStride()) 50f4c0c40fSAdam Siemieniuk return rewriter.notifyMatchFailure(descOp, "requires unit strides"); 51f4c0c40fSAdam Siemieniuk 52f4c0c40fSAdam Siemieniuk SmallVector<Value> resolvedOffsets; 53f4c0c40fSAdam Siemieniuk affine::resolveIndicesIntoOpWithOffsetsAndStrides( 54f4c0c40fSAdam Siemieniuk rewriter, descOp.getLoc(), subViewOp.getMixedOffsets(), 55f4c0c40fSAdam Siemieniuk subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), 56f4c0c40fSAdam Siemieniuk descOp.getMixedOffsets(), resolvedOffsets); 57f4c0c40fSAdam Siemieniuk 58f4c0c40fSAdam Siemieniuk rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>( 59f4c0c40fSAdam Siemieniuk descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(), 60f4c0c40fSAdam Siemieniuk getAsOpFoldResult(resolvedOffsets)); 61f4c0c40fSAdam Siemieniuk 62f4c0c40fSAdam Siemieniuk return success(); 63f4c0c40fSAdam Siemieniuk } 64f4c0c40fSAdam Siemieniuk 65f4c0c40fSAdam Siemieniuk void xegpu::populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns) { 66f4c0c40fSAdam Siemieniuk patterns.add<XegpuCreateNdDescOpSubViewOpFolder>(patterns.getContext()); 67f4c0c40fSAdam Siemieniuk } 68f4c0c40fSAdam Siemieniuk 69f4c0c40fSAdam Siemieniuk namespace { 70f4c0c40fSAdam Siemieniuk 71f4c0c40fSAdam Siemieniuk struct XeGPUFoldAliasOpsPass final 72f4c0c40fSAdam Siemieniuk : public xegpu::impl::XeGPUFoldAliasOpsBase<XeGPUFoldAliasOpsPass> { 73f4c0c40fSAdam Siemieniuk void runOnOperation() override; 74f4c0c40fSAdam Siemieniuk }; 75f4c0c40fSAdam Siemieniuk 76f4c0c40fSAdam Siemieniuk } // namespace 77f4c0c40fSAdam Siemieniuk 78f4c0c40fSAdam Siemieniuk void XeGPUFoldAliasOpsPass::runOnOperation() { 79f4c0c40fSAdam Siemieniuk RewritePatternSet patterns(&getContext()); 80f4c0c40fSAdam Siemieniuk xegpu::populateXeGPUFoldAliasOpsPatterns(patterns); 81*09dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 82f4c0c40fSAdam Siemieniuk } 83