1 //===- MPIOps.cpp - MPI dialect ops implementation ------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MPI/IR/MPI.h" 10 #include "mlir/Dialect/MemRef/IR/MemRef.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/BuiltinAttributes.h" 13 #include "mlir/IR/PatternMatch.h" 14 15 using namespace mlir; 16 using namespace mlir::mpi; 17 18 namespace { 19 20 // If input memref has dynamic shape and is a cast and if the cast's input has 21 // static shape, fold the cast's static input into the given operation. 22 template <typename OpT> 23 struct FoldCast final : public mlir::OpRewritePattern<OpT> { 24 using mlir::OpRewritePattern<OpT>::OpRewritePattern; 25 26 LogicalResult matchAndRewrite(OpT op, 27 mlir::PatternRewriter &b) const override { 28 auto mRef = op.getRef(); 29 if (mRef.getType().hasStaticShape()) { 30 return mlir::failure(); 31 } 32 auto defOp = mRef.getDefiningOp(); 33 if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) { 34 return mlir::failure(); 35 } 36 auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource(); 37 if (!src.getType().hasStaticShape()) { 38 return mlir::failure(); 39 } 40 op.getRefMutable().assign(src); 41 return mlir::success(); 42 } 43 }; 44 } // namespace 45 46 void mlir::mpi::SendOp::getCanonicalizationPatterns( 47 mlir::RewritePatternSet &results, mlir::MLIRContext *context) { 48 results.add<FoldCast<mlir::mpi::SendOp>>(context); 49 } 50 51 void mlir::mpi::RecvOp::getCanonicalizationPatterns( 52 mlir::RewritePatternSet &results, mlir::MLIRContext *context) { 53 results.add<FoldCast<mlir::mpi::RecvOp>>(context); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // TableGen'd op method definitions 58 //===----------------------------------------------------------------------===// 59 60 #define GET_OP_CLASSES 61 #include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc" 62