xref: /llvm-project/flang/lib/Optimizer/Dialect/FIRDialect.cpp (revision 6da728ad9945e070c0860cc2841deb148a7d76b4)
1 //===-- FIRDialect.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIRAttr.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 
19 using namespace fir;
20 
21 namespace {
22 /// This class defines the interface for handling inlining of FIR calls.
23 struct FIRInlinerInterface : public mlir::DialectInlinerInterface {
24   using DialectInlinerInterface::DialectInlinerInterface;
25 
26   bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable,
27                        bool wouldBeCloned) const final {
28     return fir::canLegallyInline(call, callable, wouldBeCloned);
29   }
30 
31   /// This hook checks to see if the operation `op` is legal to inline into the
32   /// given region `reg`.
33   bool isLegalToInline(mlir::Operation *op, mlir::Region *reg,
34                        bool wouldBeCloned,
35                        mlir::BlockAndValueMapping &map) const final {
36     return fir::canLegallyInline(op, reg, wouldBeCloned, map);
37   }
38 
39   /// This hook is called when a terminator operation has been inlined.
40   /// We handle the return (a Fortran FUNCTION) by replacing the values
41   /// previously returned by the call operation with the operands of the
42   /// return.
43   void handleTerminator(mlir::Operation *op,
44                         llvm::ArrayRef<mlir::Value> valuesToRepl) const final {
45     auto returnOp = cast<mlir::ReturnOp>(op);
46     assert(returnOp.getNumOperands() == valuesToRepl.size());
47     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
48       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
49   }
50 
51   mlir::Operation *materializeCallConversion(mlir::OpBuilder &builder,
52                                              mlir::Value input,
53                                              mlir::Type resultType,
54                                              mlir::Location loc) const final {
55     return builder.create<fir::ConvertOp>(loc, resultType, input);
56   }
57 };
58 } // namespace
59 
60 fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
61     : mlir::Dialect("fir", ctx, mlir::TypeID::get<FIROpsDialect>()) {
62   registerTypes();
63   registerAttributes();
64   addOperations<
65 #define GET_OP_LIST
66 #include "flang/Optimizer/Dialect/FIROps.cpp.inc"
67       >();
68   addInterfaces<FIRInlinerInterface>();
69 }
70 
71 // anchor the class vtable to this compilation unit
72 fir::FIROpsDialect::~FIROpsDialect() {
73   // do nothing
74 }
75 
76 mlir::Type fir::FIROpsDialect::parseType(mlir::DialectAsmParser &parser) const {
77   return parseFirType(const_cast<FIROpsDialect *>(this), parser);
78 }
79 
80 void fir::FIROpsDialect::printType(mlir::Type ty,
81                                    mlir::DialectAsmPrinter &p) const {
82   return printFirType(const_cast<FIROpsDialect *>(this), ty, p);
83 }
84 
85 mlir::Attribute
86 fir::FIROpsDialect::parseAttribute(mlir::DialectAsmParser &parser,
87                                    mlir::Type type) const {
88   return parseFirAttribute(const_cast<FIROpsDialect *>(this), parser, type);
89 }
90 
91 void fir::FIROpsDialect::printAttribute(mlir::Attribute attr,
92                                         mlir::DialectAsmPrinter &p) const {
93   printFirAttribute(const_cast<FIROpsDialect *>(this), attr, p);
94 }
95