1b334664fSAnton Lydike //===- MPIOps.cpp - MPI dialect ops implementation ------------------------===// 2b334664fSAnton Lydike // 3b334664fSAnton Lydike // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b334664fSAnton Lydike // See https://llvm.org/LICENSE.txt for license information. 5b334664fSAnton Lydike // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b334664fSAnton Lydike // 7b334664fSAnton Lydike //===----------------------------------------------------------------------===// 8b334664fSAnton Lydike 9b334664fSAnton Lydike #include "mlir/Dialect/MPI/IR/MPI.h" 10*79eb406aSFrank Schlimbach #include "mlir/Dialect/MemRef/IR/MemRef.h" 11b334664fSAnton Lydike #include "mlir/IR/Builders.h" 12b334664fSAnton Lydike #include "mlir/IR/BuiltinAttributes.h" 13*79eb406aSFrank Schlimbach #include "mlir/IR/PatternMatch.h" 14b334664fSAnton Lydike 15b334664fSAnton Lydike using namespace mlir; 16b334664fSAnton Lydike using namespace mlir::mpi; 17b334664fSAnton Lydike 18*79eb406aSFrank Schlimbach namespace { 19*79eb406aSFrank Schlimbach 20*79eb406aSFrank Schlimbach // If input memref has dynamic shape and is a cast and if the cast's input has 21*79eb406aSFrank Schlimbach // static shape, fold the cast's static input into the given operation. 22*79eb406aSFrank Schlimbach template <typename OpT> 23*79eb406aSFrank Schlimbach struct FoldCast final : public mlir::OpRewritePattern<OpT> { 24*79eb406aSFrank Schlimbach using mlir::OpRewritePattern<OpT>::OpRewritePattern; 25*79eb406aSFrank Schlimbach 26*79eb406aSFrank Schlimbach LogicalResult matchAndRewrite(OpT op, 27*79eb406aSFrank Schlimbach mlir::PatternRewriter &b) const override { 28*79eb406aSFrank Schlimbach auto mRef = op.getRef(); 29*79eb406aSFrank Schlimbach if (mRef.getType().hasStaticShape()) { 30*79eb406aSFrank Schlimbach return mlir::failure(); 31*79eb406aSFrank Schlimbach } 32*79eb406aSFrank Schlimbach auto defOp = mRef.getDefiningOp(); 33*79eb406aSFrank Schlimbach if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) { 34*79eb406aSFrank Schlimbach return mlir::failure(); 35*79eb406aSFrank Schlimbach } 36*79eb406aSFrank Schlimbach auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource(); 37*79eb406aSFrank Schlimbach if (!src.getType().hasStaticShape()) { 38*79eb406aSFrank Schlimbach return mlir::failure(); 39*79eb406aSFrank Schlimbach } 40*79eb406aSFrank Schlimbach op.getRefMutable().assign(src); 41*79eb406aSFrank Schlimbach return mlir::success(); 42*79eb406aSFrank Schlimbach } 43*79eb406aSFrank Schlimbach }; 44*79eb406aSFrank Schlimbach } // namespace 45*79eb406aSFrank Schlimbach 46*79eb406aSFrank Schlimbach void mlir::mpi::SendOp::getCanonicalizationPatterns( 47*79eb406aSFrank Schlimbach mlir::RewritePatternSet &results, mlir::MLIRContext *context) { 48*79eb406aSFrank Schlimbach results.add<FoldCast<mlir::mpi::SendOp>>(context); 49*79eb406aSFrank Schlimbach } 50*79eb406aSFrank Schlimbach 51*79eb406aSFrank Schlimbach void mlir::mpi::RecvOp::getCanonicalizationPatterns( 52*79eb406aSFrank Schlimbach mlir::RewritePatternSet &results, mlir::MLIRContext *context) { 53*79eb406aSFrank Schlimbach results.add<FoldCast<mlir::mpi::RecvOp>>(context); 54*79eb406aSFrank Schlimbach } 55*79eb406aSFrank Schlimbach 56b334664fSAnton Lydike //===----------------------------------------------------------------------===// 57b334664fSAnton Lydike // TableGen'd op method definitions 58b334664fSAnton Lydike //===----------------------------------------------------------------------===// 59b334664fSAnton Lydike 60b334664fSAnton Lydike #define GET_OP_CLASSES 61b334664fSAnton Lydike #include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc" 62