xref: /llvm-project/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp (revision 599c73990532333e62edf8ba19a5302b543f976f)
1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10 
11 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
13 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21 #include "mlir/Pass/Pass.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 /// A pattern that converts the region arguments in a single-region OpenMP
32 /// operation to the LLVM dialect. The body of the region is not modified and is
33 /// expected to either be processed by the conversion infrastructure or already
34 /// contain ops compatible with LLVM dialect types.
35 template <typename OpType>
36 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
37   using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
38 
39   LogicalResult
40   matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
41                   ConversionPatternRewriter &rewriter) const override {
42     auto newOp = rewriter.create<OpType>(
43         curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
44     rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
45                                 newOp.getRegion().end());
46     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
47                                            *this->getTypeConverter())))
48       return failure();
49 
50     rewriter.eraseOp(curOp);
51     return success();
52   }
53 };
54 
55 template <typename T>
56 struct RegionLessOpWithVarOperandsConversion
57     : public ConvertOpToLLVMPattern<T> {
58   using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
59   LogicalResult
60   matchAndRewrite(T curOp, typename T::Adaptor adaptor,
61                   ConversionPatternRewriter &rewriter) const override {
62     const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
63     SmallVector<Type> resTypes;
64     if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
65       return failure();
66     SmallVector<Value> convertedOperands;
67     assert(curOp.getNumVariableOperands() ==
68                curOp.getOperation()->getNumOperands() &&
69            "unexpected non-variable operands");
70     for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
71       Value originalVariableOperand = curOp.getVariableOperand(idx);
72       if (!originalVariableOperand)
73         return failure();
74       if (isa<MemRefType>(originalVariableOperand.getType())) {
75         // TODO: Support memref type in variable operands
76         return rewriter.notifyMatchFailure(curOp,
77                                            "memref is not supported yet");
78       }
79       convertedOperands.emplace_back(adaptor.getOperands()[idx]);
80     }
81 
82     rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
83                                    curOp->getAttrs());
84     return success();
85   }
86 };
87 
88 template <typename T>
89 struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
90   using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
91   LogicalResult
92   matchAndRewrite(T curOp, typename T::Adaptor adaptor,
93                   ConversionPatternRewriter &rewriter) const override {
94     const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
95     SmallVector<Type> resTypes;
96     if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
97       return failure();
98     SmallVector<Value> convertedOperands;
99     assert(curOp.getNumVariableOperands() ==
100                curOp.getOperation()->getNumOperands() &&
101            "unexpected non-variable operands");
102     for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
103       Value originalVariableOperand = curOp.getVariableOperand(idx);
104       if (!originalVariableOperand)
105         return failure();
106       if (isa<MemRefType>(originalVariableOperand.getType())) {
107         // TODO: Support memref type in variable operands
108         return rewriter.notifyMatchFailure(curOp,
109                                            "memref is not supported yet");
110       }
111       convertedOperands.emplace_back(adaptor.getOperands()[idx]);
112     }
113     auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
114                                     curOp->getAttrs());
115     rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
116                                 newOp.getRegion().end());
117     if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
118                                            *this->getTypeConverter())))
119       return failure();
120 
121     rewriter.eraseOp(curOp);
122     return success();
123   }
124 };
125 
126 template <typename T>
127 struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
128   using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
129   LogicalResult
130   matchAndRewrite(T curOp, typename T::Adaptor adaptor,
131                   ConversionPatternRewriter &rewriter) const override {
132     const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
133     SmallVector<Type> resTypes;
134     if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
135       return failure();
136 
137     rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
138                                    curOp->getAttrs());
139     return success();
140   }
141 };
142 
143 struct AtomicReadOpConversion
144     : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
145   using ConvertOpToLLVMPattern<omp::AtomicReadOp>::ConvertOpToLLVMPattern;
146   LogicalResult
147   matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
148                   ConversionPatternRewriter &rewriter) const override {
149     const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
150     Type curElementType = curOp.getElementType();
151     auto newOp = rewriter.create<omp::AtomicReadOp>(
152         curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
153     TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
154     newOp.setElementTypeAttr(typeAttr);
155     rewriter.eraseOp(curOp);
156     return success();
157   }
158 };
159 
160 struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
161   using ConvertOpToLLVMPattern<omp::MapInfoOp>::ConvertOpToLLVMPattern;
162   LogicalResult
163   matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
164                   ConversionPatternRewriter &rewriter) const override {
165     const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
166 
167     SmallVector<Type> resTypes;
168     if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
169       return failure();
170 
171     // Copy attributes of the curOp except for the typeAttr which should
172     // be converted
173     SmallVector<NamedAttribute> newAttrs;
174     for (NamedAttribute attr : curOp->getAttrs()) {
175       if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
176         Type newAttr = converter->convertType(typeAttr.getValue());
177         newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
178       } else {
179         newAttrs.push_back(attr);
180       }
181     }
182 
183     rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
184         curOp, resTypes, adaptor.getOperands(), newAttrs);
185     return success();
186   }
187 };
188 
189 template <typename OpType>
190 struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
191   using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
192 
193   void forwardOpAttrs(OpType curOp, OpType newOp) const {}
194 
195   LogicalResult
196   matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
197                   ConversionPatternRewriter &rewriter) const override {
198     auto newOp = rewriter.create<OpType>(
199         curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
200         TypeAttr::get(this->getTypeConverter()->convertType(
201             curOp.getTypeAttr().getValue())));
202     forwardOpAttrs(curOp, newOp);
203 
204     for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
205       rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
206                                   newOp.getRegion(idx).end());
207       if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx),
208                                              *this->getTypeConverter())))
209         return failure();
210     }
211 
212     rewriter.eraseOp(curOp);
213     return success();
214   }
215 };
216 
217 template <>
218 void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
219     omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
220   newOp.setDataSharingType(curOp.getDataSharingType());
221 }
222 } // namespace
223 
224 void mlir::configureOpenMPToLLVMConversionLegality(
225     ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
226   target.addDynamicallyLegalOp<
227       omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
228       omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
229       omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
230       omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
231       omp::YieldOp>([&](Operation *op) {
232     return typeConverter.isLegal(op->getOperandTypes()) &&
233            typeConverter.isLegal(op->getResultTypes());
234   });
235   target.addDynamicallyLegalOp<
236       omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareReductionOp,
237       omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, omp::MasterOp,
238       omp::OrderedRegionOp, omp::ParallelOp, omp::PrivateClauseOp,
239       omp::SectionOp, omp::SectionsOp, omp::SimdOp, omp::SingleOp,
240       omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp, omp::TaskloopOp,
241       omp::TaskOp, omp::TeamsOp, omp::WsloopOp>([&](Operation *op) {
242     return std::all_of(op->getRegions().begin(), op->getRegions().end(),
243                        [&](Region &region) {
244                          return typeConverter.isLegal(&region);
245                        }) &&
246            typeConverter.isLegal(op->getOperandTypes()) &&
247            typeConverter.isLegal(op->getResultTypes());
248   });
249 }
250 
251 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
252                                                   RewritePatternSet &patterns) {
253   // This type is allowed when converting OpenMP to LLVM Dialect, it carries
254   // bounds information for map clauses and the operation and type are
255   // discarded on lowering to LLVM-IR from the OpenMP dialect.
256   converter.addConversion(
257       [&](omp::MapBoundsType type) -> Type { return type; });
258 
259   patterns.add<
260       AtomicReadOpConversion, MapInfoOpConversion,
261       MultiRegionOpConversion<omp::DeclareReductionOp>,
262       MultiRegionOpConversion<omp::PrivateClauseOp>,
263       RegionLessOpConversion<omp::CancellationPointOp>,
264       RegionLessOpConversion<omp::CancelOp>,
265       RegionLessOpConversion<omp::CriticalDeclareOp>,
266       RegionLessOpConversion<omp::OrderedOp>,
267       RegionLessOpConversion<omp::TargetEnterDataOp>,
268       RegionLessOpConversion<omp::TargetExitDataOp>,
269       RegionLessOpConversion<omp::TargetUpdateOp>,
270       RegionLessOpConversion<omp::YieldOp>,
271       RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
272       RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
273       RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>,
274       RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>,
275       RegionOpConversion<omp::AtomicCaptureOp>,
276       RegionOpConversion<omp::CriticalOp>,
277       RegionOpConversion<omp::DistributeOp>,
278       RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>,
279       RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>,
280       RegionOpConversion<omp::OrderedRegionOp>,
281       RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>,
282       RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>,
283       RegionOpConversion<omp::SingleOp>, RegionOpConversion<omp::TargetDataOp>,
284       RegionOpConversion<omp::TargetOp>, RegionOpConversion<omp::TaskgroupOp>,
285       RegionOpConversion<omp::TaskloopOp>, RegionOpConversion<omp::TaskOp>,
286       RegionOpConversion<omp::TeamsOp>, RegionOpConversion<omp::WsloopOp>,
287       RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>>(converter);
288 }
289 
290 namespace {
291 struct ConvertOpenMPToLLVMPass
292     : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> {
293   using Base::Base;
294 
295   void runOnOperation() override;
296 };
297 } // namespace
298 
299 void ConvertOpenMPToLLVMPass::runOnOperation() {
300   auto module = getOperation();
301 
302   // Convert to OpenMP operations with LLVM IR dialect
303   RewritePatternSet patterns(&getContext());
304   LLVMTypeConverter converter(&getContext());
305   arith::populateArithToLLVMConversionPatterns(converter, patterns);
306   cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
307   cf::populateAssertToLLVMConversionPattern(converter, patterns);
308   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
309   populateFuncToLLVMConversionPatterns(converter, patterns);
310   populateOpenMPToLLVMConversionPatterns(converter, patterns);
311 
312   LLVMConversionTarget target(getContext());
313   target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp,
314                     omp::TaskyieldOp, omp::TerminatorOp>();
315   configureOpenMPToLLVMConversionLegality(target, converter);
316   if (failed(applyPartialConversion(module, target, std::move(patterns))))
317     signalPassFailure();
318 }
319 
320 //===----------------------------------------------------------------------===//
321 // ConvertToLLVMPatternInterface implementation
322 //===----------------------------------------------------------------------===//
323 namespace {
324 /// Implement the interface to convert OpenMP to LLVM.
325 struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
326   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
327   void loadDependentDialects(MLIRContext *context) const final {
328     context->loadDialect<LLVM::LLVMDialect>();
329   }
330 
331   /// Hook for derived dialect interface to provide conversion patterns
332   /// and mark dialect legal for the conversion target.
333   void populateConvertToLLVMConversionPatterns(
334       ConversionTarget &target, LLVMTypeConverter &typeConverter,
335       RewritePatternSet &patterns) const final {
336     configureOpenMPToLLVMConversionLegality(target, typeConverter);
337     populateOpenMPToLLVMConversionPatterns(typeConverter, patterns);
338   }
339 };
340 } // namespace
341 
342 void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry &registry) {
343   registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
344     dialect->addInterfaces<OpenMPToLLVMDialectInterface>();
345   });
346 }
347