xref: /llvm-project/flang/lib/Optimizer/Dialect/FIRDialect.cpp (revision aa6b47cdaf3cddc70b7af33c1edbda502ecb3d05)
1 //===-- FIRDialect.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIRAttr.h"
15 #include "flang/Optimizer/Dialect/FIROps.h"
16 #include "flang/Optimizer/Dialect/FIRType.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 
19 using namespace fir;
20 
21 namespace {
22 /// This class defines the interface for handling inlining of FIR calls.
23 struct FIRInlinerInterface : public mlir::DialectInlinerInterface {
24   using DialectInlinerInterface::DialectInlinerInterface;
25 
26   bool isLegalToInline(mlir::Operation *call, mlir::Operation *callable,
27                        bool wouldBeCloned) const final {
28     return fir::canLegallyInline(call, callable, wouldBeCloned);
29   }
30 
31   /// This hook checks to see if the operation `op` is legal to inline into the
32   /// given region `reg`.
33   bool isLegalToInline(mlir::Operation *op, mlir::Region *reg,
34                        bool wouldBeCloned, mlir::IRMapping &map) const final {
35     return fir::canLegallyInline(op, reg, wouldBeCloned, map);
36   }
37 
38   /// This hook is called when a terminator operation has been inlined.
39   /// We handle the return (a Fortran FUNCTION) by replacing the values
40   /// previously returned by the call operation with the operands of the
41   /// return.
42   void handleTerminator(mlir::Operation *op,
43                         llvm::ArrayRef<mlir::Value> valuesToRepl) const final {
44     auto returnOp = llvm::cast<mlir::func::ReturnOp>(op);
45     assert(returnOp.getNumOperands() == valuesToRepl.size());
46     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
47       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
48   }
49 
50   mlir::Operation *materializeCallConversion(mlir::OpBuilder &builder,
51                                              mlir::Value input,
52                                              mlir::Type resultType,
53                                              mlir::Location loc) const final {
54     return builder.create<fir::ConvertOp>(loc, resultType, input);
55   }
56 };
57 } // namespace
58 
59 fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
60     : mlir::Dialect("fir", ctx, mlir::TypeID::get<FIROpsDialect>()) {
61   registerTypes();
62   registerAttributes();
63   addOperations<
64 #define GET_OP_LIST
65 #include "flang/Optimizer/Dialect/FIROps.cpp.inc"
66       >();
67   addInterfaces<FIRInlinerInterface>();
68 }
69 
70 // anchor the class vtable to this compilation unit
71 fir::FIROpsDialect::~FIROpsDialect() {
72   // do nothing
73 }
74 
75 mlir::Type fir::FIROpsDialect::parseType(mlir::DialectAsmParser &parser) const {
76   return parseFirType(const_cast<FIROpsDialect *>(this), parser);
77 }
78 
79 void fir::FIROpsDialect::printType(mlir::Type ty,
80                                    mlir::DialectAsmPrinter &p) const {
81   return printFirType(const_cast<FIROpsDialect *>(this), ty, p);
82 }
83 
84 mlir::Attribute
85 fir::FIROpsDialect::parseAttribute(mlir::DialectAsmParser &parser,
86                                    mlir::Type type) const {
87   return parseFirAttribute(const_cast<FIROpsDialect *>(this), parser, type);
88 }
89 
90 void fir::FIROpsDialect::printAttribute(mlir::Attribute attr,
91                                         mlir::DialectAsmPrinter &p) const {
92   printFirAttribute(const_cast<FIROpsDialect *>(this), attr, p);
93 }
94