xref: /llvm-project/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp (revision c5de6611ce10b8ecf573f601b5f12de60424897d)
1 //===- FunctionFiltering.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements transforms to filter out functions intended for the host
10 // when compiling for the device and vice versa.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/OpenMP/Passes.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
23 
namespace flangomp {
// Defining GEN_PASS_DEF_FUNCTIONFILTERINGPASS before including the generated
// Passes.h.inc emits the TableGen-generated definition of this pass's base
// class, flangomp::impl::FunctionFilteringPassBase, which
// FunctionFilteringPass derives from.
#define GEN_PASS_DEF_FUNCTIONFILTERINGPASS
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp
28 
29 using namespace mlir;
30 
31 namespace {
32 class FunctionFilteringPass
33     : public flangomp::impl::FunctionFilteringPassBase<FunctionFilteringPass> {
34 public:
35   FunctionFilteringPass() = default;
36 
37   void runOnOperation() override {
38     MLIRContext *context = &getContext();
39     OpBuilder opBuilder(context);
40     auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
41     if (!op || !op.getIsTargetDevice())
42       return;
43 
44     op->walk<WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
45       // Do not filter functions with target regions inside, because they have
46       // to be available for both host and device so that regular and reverse
47       // offloading can be supported.
48       bool hasTargetRegion =
49           funcOp
50               ->walk<WalkOrder::PreOrder>(
51                   [&](omp::TargetOp) { return WalkResult::interrupt(); })
52               .wasInterrupted();
53 
54       omp::DeclareTargetDeviceType declareType =
55           omp::DeclareTargetDeviceType::host;
56       auto declareTargetOp =
57           dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
58       if (declareTargetOp && declareTargetOp.isDeclareTarget())
59         declareType = declareTargetOp.getDeclareTargetDeviceType();
60 
61       // Filtering a function here means deleting it if it doesn't contain a
62       // target region. Else we explicitly set the omp.declare_target
63       // attribute. The second stage of function filtering at the MLIR to LLVM
64       // IR translation level will remove functions that contain the target
65       // region from the generated llvm IR.
66       if (declareType == omp::DeclareTargetDeviceType::host) {
67         SymbolTable::UseRange funcUses = *funcOp.getSymbolUses(op);
68         for (SymbolTable::SymbolUse use : funcUses) {
69           Operation *callOp = use.getUser();
70           if (auto internalFunc = mlir::dyn_cast<func::FuncOp>(callOp)) {
71             // Do not delete internal procedures holding the symbol of their
72             // Fortran host procedure as attribute.
73             internalFunc->removeAttr(fir::getHostSymbolAttrName());
74             // Set public visibility so that the function is not deleted by MLIR
75             // because unused. Changing it is OK here because the function will
76             // be deleted anyway in the second filtering phase.
77             internalFunc.setVisibility(mlir::SymbolTable::Visibility::Public);
78             continue;
79           }
80           // If the callOp has users then replace them with Undef values.
81           if (!callOp->use_empty()) {
82             SmallVector<Value> undefResults;
83             for (Value res : callOp->getResults()) {
84               opBuilder.setInsertionPoint(callOp);
85               undefResults.emplace_back(
86                   opBuilder.create<fir::UndefOp>(res.getLoc(), res.getType()));
87             }
88             callOp->replaceAllUsesWith(undefResults);
89           }
90           // Remove the callOp
91           callOp->erase();
92         }
93         if (!hasTargetRegion) {
94           funcOp.erase();
95           return WalkResult::skip();
96         }
97         if (declareTargetOp)
98           declareTargetOp.setDeclareTarget(declareType,
99                                            omp::DeclareTargetCaptureClause::to);
100       }
101       return WalkResult::advance();
102     });
103   }
104 };
105 } // namespace
106