xref: /llvm-project/mlir/lib/Dialect/MPI/IR/MPIOps.cpp (revision 79eb406a67fe08458548289da72cda18248a9313)
1b334664fSAnton Lydike //===- MPIOps.cpp - MPI dialect ops implementation ------------------------===//
2b334664fSAnton Lydike //
3b334664fSAnton Lydike // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b334664fSAnton Lydike // See https://llvm.org/LICENSE.txt for license information.
5b334664fSAnton Lydike // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b334664fSAnton Lydike //
7b334664fSAnton Lydike //===----------------------------------------------------------------------===//
8b334664fSAnton Lydike 
9b334664fSAnton Lydike #include "mlir/Dialect/MPI/IR/MPI.h"
10*79eb406aSFrank Schlimbach #include "mlir/Dialect/MemRef/IR/MemRef.h"
11b334664fSAnton Lydike #include "mlir/IR/Builders.h"
12b334664fSAnton Lydike #include "mlir/IR/BuiltinAttributes.h"
13*79eb406aSFrank Schlimbach #include "mlir/IR/PatternMatch.h"
14b334664fSAnton Lydike 
15b334664fSAnton Lydike using namespace mlir;
16b334664fSAnton Lydike using namespace mlir::mpi;
17b334664fSAnton Lydike 
18*79eb406aSFrank Schlimbach namespace {
19*79eb406aSFrank Schlimbach 
20*79eb406aSFrank Schlimbach // If input memref has dynamic shape and is a cast and if the cast's input has
21*79eb406aSFrank Schlimbach // static shape, fold the cast's static input into the given operation.
22*79eb406aSFrank Schlimbach template <typename OpT>
23*79eb406aSFrank Schlimbach struct FoldCast final : public mlir::OpRewritePattern<OpT> {
24*79eb406aSFrank Schlimbach   using mlir::OpRewritePattern<OpT>::OpRewritePattern;
25*79eb406aSFrank Schlimbach 
26*79eb406aSFrank Schlimbach   LogicalResult matchAndRewrite(OpT op,
27*79eb406aSFrank Schlimbach                                 mlir::PatternRewriter &b) const override {
28*79eb406aSFrank Schlimbach     auto mRef = op.getRef();
29*79eb406aSFrank Schlimbach     if (mRef.getType().hasStaticShape()) {
30*79eb406aSFrank Schlimbach       return mlir::failure();
31*79eb406aSFrank Schlimbach     }
32*79eb406aSFrank Schlimbach     auto defOp = mRef.getDefiningOp();
33*79eb406aSFrank Schlimbach     if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
34*79eb406aSFrank Schlimbach       return mlir::failure();
35*79eb406aSFrank Schlimbach     }
36*79eb406aSFrank Schlimbach     auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
37*79eb406aSFrank Schlimbach     if (!src.getType().hasStaticShape()) {
38*79eb406aSFrank Schlimbach       return mlir::failure();
39*79eb406aSFrank Schlimbach     }
40*79eb406aSFrank Schlimbach     op.getRefMutable().assign(src);
41*79eb406aSFrank Schlimbach     return mlir::success();
42*79eb406aSFrank Schlimbach   }
43*79eb406aSFrank Schlimbach };
44*79eb406aSFrank Schlimbach } // namespace
45*79eb406aSFrank Schlimbach 
46*79eb406aSFrank Schlimbach void mlir::mpi::SendOp::getCanonicalizationPatterns(
47*79eb406aSFrank Schlimbach     mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
48*79eb406aSFrank Schlimbach   results.add<FoldCast<mlir::mpi::SendOp>>(context);
49*79eb406aSFrank Schlimbach }
50*79eb406aSFrank Schlimbach 
51*79eb406aSFrank Schlimbach void mlir::mpi::RecvOp::getCanonicalizationPatterns(
52*79eb406aSFrank Schlimbach     mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
53*79eb406aSFrank Schlimbach   results.add<FoldCast<mlir::mpi::RecvOp>>(context);
54*79eb406aSFrank Schlimbach }
55*79eb406aSFrank Schlimbach 
56b334664fSAnton Lydike //===----------------------------------------------------------------------===//
57b334664fSAnton Lydike // TableGen'd op method definitions
58b334664fSAnton Lydike //===----------------------------------------------------------------------===//
59b334664fSAnton Lydike 
60b334664fSAnton Lydike #define GET_OP_CLASSES
61b334664fSAnton Lydike #include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
62