xref: /llvm-project/flang/lib/Optimizer/Dialect/FIRDialect.cpp (revision 26a0b277369adc31b162b1cc38b1a712bc10c1a0)
1 //===-- FIRDialect.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIRAttr.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 
20 using namespace fir;
21 
22 namespace {
23 /// This class defines the interface for handling inlining of FIR calls.
24 struct FIRInlinerInterface : public mlir::DialectInlinerInterface {
25   using DialectInlinerInterface::DialectInlinerInterface;
26 
27   bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable,
28                        bool wouldBeCloned) const final {
29     return fir::canLegallyInline(call, callable, wouldBeCloned);
30   }
31 
32   /// This hook checks to see if the operation `op` is legal to inline into the
33   /// given region `reg`.
34   bool isLegalToInline(mlir::Operation *op, mlir::Region *reg,
35                        bool wouldBeCloned, mlir::IRMapping &map) const final {
36     return fir::canLegallyInline(op, reg, wouldBeCloned, map);
37   }
38 
39   /// This hook is called when a terminator operation has been inlined.
40   /// We handle the return (a Fortran FUNCTION) by replacing the values
41   /// previously returned by the call operation with the operands of the
42   /// return.
43   void handleTerminator(mlir::Operation *op,
44                         mlir::ValueRange valuesToRepl) const final {
45     auto returnOp = llvm::cast<mlir::func::ReturnOp>(op);
46     assert(returnOp.getNumOperands() == valuesToRepl.size());
47     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
48       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
49   }
50 
51   mlir::Operation *materializeCallConversion(mlir::OpBuilder &builder,
52                                              mlir::Value input,
53                                              mlir::Type resultType,
54                                              mlir::Location loc) const final {
55     return builder.create<fir::ConvertOp>(loc, resultType, input);
56   }
57 };
58 } // namespace
59 
60 fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
61     : mlir::Dialect("fir", ctx, mlir::TypeID::get<FIROpsDialect>()) {
62   getContext()->loadDialect<mlir::LLVM::LLVMDialect>();
63   registerTypes();
64   registerAttributes();
65   addOperations<
66 #define GET_OP_LIST
67 #include "flang/Optimizer/Dialect/FIROps.cpp.inc"
68       >();
69   registerOpExternalInterfaces();
70   addInterfaces<FIRInlinerInterface>();
71 }
72 
73 // anchor the class vtable to this compilation unit
74 fir::FIROpsDialect::~FIROpsDialect() {
75   // do nothing
76 }
77 
78 mlir::Type fir::FIROpsDialect::parseType(mlir::DialectAsmParser &parser) const {
79   return parseFirType(const_cast<FIROpsDialect *>(this), parser);
80 }
81 
82 void fir::FIROpsDialect::printType(mlir::Type ty,
83                                    mlir::DialectAsmPrinter &p) const {
84   return printFirType(const_cast<FIROpsDialect *>(this), ty, p);
85 }
86 
87 mlir::Attribute
88 fir::FIROpsDialect::parseAttribute(mlir::DialectAsmParser &parser,
89                                    mlir::Type type) const {
90   return parseFirAttribute(const_cast<FIROpsDialect *>(this), parser, type);
91 }
92 
93 void fir::FIROpsDialect::printAttribute(mlir::Attribute attr,
94                                         mlir::DialectAsmPrinter &p) const {
95   printFirAttribute(const_cast<FIROpsDialect *>(this), attr, p);
96 }
97