xref: /llvm-project/flang/lib/Optimizer/CodeGen/CodeGenOpenMP.cpp (revision 206fad0e218e83799e49ca15545d997c6c5e8a03)
1 //===-- CodeGenOpenMP.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/CodeGenOpenMP.h"
14 
15 #include "flang/Optimizer/Builder/FIRBuilder.h"
16 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
17 #include "flang/Optimizer/CodeGen/CodeGen.h"
18 #include "flang/Optimizer/Dialect/FIRDialect.h"
19 #include "flang/Optimizer/Dialect/FIROps.h"
20 #include "flang/Optimizer/Dialect/FIRType.h"
21 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
22 #include "flang/Optimizer/Support/FatalError.h"
23 #include "flang/Optimizer/Support/InternalNames.h"
24 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
25 #include "mlir/Conversion/LLVMCommon/Pattern.h"
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 
31 using namespace fir;
32 
33 #define DEBUG_TYPE "flang-codegen-openmp"
34 
35 // fir::LLVMTypeConverter for converting to LLVM IR dialect types.
36 #include "flang/Optimizer/CodeGen/TypeConverter.h"
37 
38 namespace {
39 /// A pattern that converts the region arguments in a single-region OpenMP
40 /// operation to the LLVM dialect. The body of the region is not modified and is
41 /// expected to either be processed by the conversion infrastructure or already
42 /// contain ops compatible with LLVM dialect types.
43 template <typename OpType>
44 class OpenMPFIROpConversion : public mlir::ConvertOpToLLVMPattern<OpType> {
45 public:
46   explicit OpenMPFIROpConversion(const fir::LLVMTypeConverter &lowering)
47       : mlir::ConvertOpToLLVMPattern<OpType>(lowering) {}
48 
49   const fir::LLVMTypeConverter &lowerTy() const {
50     return *static_cast<const fir::LLVMTypeConverter *>(
51         this->getTypeConverter());
52   }
53 };
54 
55 // FIR Op specific conversion for MapInfoOp that overwrites the default OpenMP
56 // Dialect lowering, this allows FIR specific lowering of types, required for
57 // descriptors of allocatables currently.
58 struct MapInfoOpConversion
59     : public OpenMPFIROpConversion<mlir::omp::MapInfoOp> {
60   using OpenMPFIROpConversion::OpenMPFIROpConversion;
61 
62   llvm::LogicalResult
63   matchAndRewrite(mlir::omp::MapInfoOp curOp, OpAdaptor adaptor,
64                   mlir::ConversionPatternRewriter &rewriter) const override {
65     const mlir::TypeConverter *converter = getTypeConverter();
66     llvm::SmallVector<mlir::Type> resTypes;
67     if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
68       return mlir::failure();
69 
70     llvm::SmallVector<mlir::NamedAttribute> newAttrs;
71     mlir::omp::MapInfoOp newOp;
72     for (mlir::NamedAttribute attr : curOp->getAttrs()) {
73       if (auto typeAttr = mlir::dyn_cast<mlir::TypeAttr>(attr.getValue())) {
74         mlir::Type newAttr;
75         if (fir::isTypeWithDescriptor(typeAttr.getValue())) {
76           newAttr = lowerTy().convertBoxTypeAsStruct(
77               mlir::cast<fir::BaseBoxType>(typeAttr.getValue()));
78         } else {
79           newAttr = converter->convertType(typeAttr.getValue());
80         }
81         newAttrs.emplace_back(attr.getName(), mlir::TypeAttr::get(newAttr));
82       } else {
83         newAttrs.push_back(attr);
84       }
85     }
86 
87     rewriter.replaceOpWithNewOp<mlir::omp::MapInfoOp>(
88         curOp, resTypes, adaptor.getOperands(), newAttrs);
89 
90     return mlir::success();
91   }
92 };
93 } // namespace
94 
95 void fir::populateOpenMPFIRToLLVMConversionPatterns(
96     const LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns) {
97   patterns.add<MapInfoOpConversion>(converter);
98 }
99