1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" 10 11 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" 12 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 13 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" 14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 17 #include "mlir/Conversion/LLVMCommon/Pattern.h" 18 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 21 #include "mlir/Pass/Pass.h" 22 23 namespace mlir { 24 #define GEN_PASS_DEF_CONVERTOPENMPTOLLVMPASS 25 #include "mlir/Conversion/Passes.h.inc" 26 } // namespace mlir 27 28 using namespace mlir; 29 30 namespace { 31 /// A pattern that converts the region arguments in a single-region OpenMP 32 /// operation to the LLVM dialect. The body of the region is not modified and is 33 /// expected to either be processed by the conversion infrastructure or already 34 /// contain ops compatible with LLVM dialect types. 35 template <typename OpType> 36 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> { 37 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern; 38 39 LogicalResult 40 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor, 41 ConversionPatternRewriter &rewriter) const override { 42 auto newOp = rewriter.create<OpType>( 43 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); 44 rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(), 45 newOp.getRegion().end()); 46 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), 47 *this->getTypeConverter()))) 48 return failure(); 49 50 rewriter.eraseOp(curOp); 51 return success(); 52 } 53 }; 54 55 template <typename T> 56 struct RegionLessOpWithVarOperandsConversion 57 : public ConvertOpToLLVMPattern<T> { 58 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern; 59 LogicalResult 60 matchAndRewrite(T curOp, typename T::Adaptor adaptor, 61 ConversionPatternRewriter &rewriter) const override { 62 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); 63 SmallVector<Type> resTypes; 64 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) 65 return failure(); 66 SmallVector<Value> convertedOperands; 67 assert(curOp.getNumVariableOperands() == 68 curOp.getOperation()->getNumOperands() && 69 "unexpected non-variable operands"); 70 for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { 71 Value originalVariableOperand = curOp.getVariableOperand(idx); 72 if (!originalVariableOperand) 73 return failure(); 74 if (isa<MemRefType>(originalVariableOperand.getType())) { 75 // TODO: Support memref type in variable operands 76 return rewriter.notifyMatchFailure(curOp, 77 "memref is not supported yet"); 78 } 79 convertedOperands.emplace_back(adaptor.getOperands()[idx]); 80 } 81 82 rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands, 83 curOp->getAttrs()); 84 return success(); 85 } 86 }; 87 88 template <typename T> 89 struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> { 90 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern; 91 LogicalResult 92 matchAndRewrite(T curOp, typename T::Adaptor adaptor, 93 ConversionPatternRewriter &rewriter) const override { 94 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); 95 SmallVector<Type> resTypes; 96 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) 97 return failure(); 98 SmallVector<Value> convertedOperands; 99 assert(curOp.getNumVariableOperands() == 100 curOp.getOperation()->getNumOperands() && 101 "unexpected non-variable operands"); 102 for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { 103 Value originalVariableOperand = curOp.getVariableOperand(idx); 104 if (!originalVariableOperand) 105 return failure(); 106 if (isa<MemRefType>(originalVariableOperand.getType())) { 107 // TODO: Support memref type in variable operands 108 return rewriter.notifyMatchFailure(curOp, 109 "memref is not supported yet"); 110 } 111 convertedOperands.emplace_back(adaptor.getOperands()[idx]); 112 } 113 auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands, 114 curOp->getAttrs()); 115 rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(), 116 newOp.getRegion().end()); 117 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), 118 *this->getTypeConverter()))) 119 return failure(); 120 121 rewriter.eraseOp(curOp); 122 return success(); 123 } 124 }; 125 126 template <typename T> 127 struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> { 128 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern; 129 LogicalResult 130 matchAndRewrite(T curOp, typename T::Adaptor adaptor, 131 ConversionPatternRewriter &rewriter) const override { 132 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); 133 SmallVector<Type> resTypes; 134 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) 135 return failure(); 136 137 rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(), 138 curOp->getAttrs()); 139 return success(); 140 } 141 }; 142 143 struct AtomicReadOpConversion 144 : public ConvertOpToLLVMPattern<omp::AtomicReadOp> { 145 using ConvertOpToLLVMPattern<omp::AtomicReadOp>::ConvertOpToLLVMPattern; 146 LogicalResult 147 matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor, 148 ConversionPatternRewriter &rewriter) const override { 149 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); 150 Type curElementType = curOp.getElementType(); 151 auto newOp = rewriter.create<omp::AtomicReadOp>( 152 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); 153 TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType)); 154 newOp.setElementTypeAttr(typeAttr); 155 rewriter.eraseOp(curOp); 156 return success(); 157 } 158 }; 159 160 struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> { 161 using ConvertOpToLLVMPattern<omp::MapInfoOp>::ConvertOpToLLVMPattern; 162 LogicalResult 163 matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor, 164 ConversionPatternRewriter &rewriter) const override { 165 const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); 166 167 SmallVector<Type> resTypes; 168 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) 169 return failure(); 170 171 // Copy attributes of the curOp except for the typeAttr which should 172 // be converted 173 SmallVector<NamedAttribute> newAttrs; 174 for (NamedAttribute attr : curOp->getAttrs()) { 175 if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) { 176 Type newAttr = converter->convertType(typeAttr.getValue()); 177 newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr)); 178 } else { 179 newAttrs.push_back(attr); 180 } 181 } 182 183 rewriter.replaceOpWithNewOp<omp::MapInfoOp>( 184 curOp, resTypes, adaptor.getOperands(), newAttrs); 185 return success(); 186 } 187 }; 188 189 template <typename OpType> 190 struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> { 191 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern; 192 193 void forwardOpAttrs(OpType curOp, OpType newOp) const {} 194 195 LogicalResult 196 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor, 197 ConversionPatternRewriter &rewriter) const override { 198 auto newOp = rewriter.create<OpType>( 199 curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(), 200 TypeAttr::get(this->getTypeConverter()->convertType( 201 curOp.getTypeAttr().getValue()))); 202 forwardOpAttrs(curOp, newOp); 203 204 for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) { 205 rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx), 206 newOp.getRegion(idx).end()); 207 if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx), 208 *this->getTypeConverter()))) 209 return failure(); 210 } 211 212 rewriter.eraseOp(curOp); 213 return success(); 214 } 215 }; 216 217 template <> 218 void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs( 219 omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const { 220 newOp.setDataSharingType(curOp.getDataSharingType()); 221 } 222 } // namespace 223 224 void mlir::configureOpenMPToLLVMConversionLegality( 225 ConversionTarget &target, const LLVMTypeConverter &typeConverter) { 226 target.addDynamicallyLegalOp< 227 omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp, 228 omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp, 229 omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp, 230 omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp, 231 omp::YieldOp>([&](Operation *op) { 232 return typeConverter.isLegal(op->getOperandTypes()) && 233 typeConverter.isLegal(op->getResultTypes()); 234 }); 235 target.addDynamicallyLegalOp< 236 omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareReductionOp, 237 omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, omp::MasterOp, 238 omp::OrderedRegionOp, omp::ParallelOp, omp::PrivateClauseOp, 239 omp::SectionOp, omp::SectionsOp, omp::SimdOp, omp::SingleOp, 240 omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp, omp::TaskloopOp, 241 omp::TaskOp, omp::TeamsOp, omp::WsloopOp>([&](Operation *op) { 242 return std::all_of(op->getRegions().begin(), op->getRegions().end(), 243 [&](Region ®ion) { 244 return typeConverter.isLegal(®ion); 245 }) && 246 typeConverter.isLegal(op->getOperandTypes()) && 247 typeConverter.isLegal(op->getResultTypes()); 248 }); 249 } 250 251 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, 252 RewritePatternSet &patterns) { 253 // This type is allowed when converting OpenMP to LLVM Dialect, it carries 254 // bounds information for map clauses and the operation and type are 255 // discarded on lowering to LLVM-IR from the OpenMP dialect. 256 converter.addConversion( 257 [&](omp::MapBoundsType type) -> Type { return type; }); 258 259 patterns.add< 260 AtomicReadOpConversion, MapInfoOpConversion, 261 MultiRegionOpConversion<omp::DeclareReductionOp>, 262 MultiRegionOpConversion<omp::PrivateClauseOp>, 263 RegionLessOpConversion<omp::CancellationPointOp>, 264 RegionLessOpConversion<omp::CancelOp>, 265 RegionLessOpConversion<omp::CriticalDeclareOp>, 266 RegionLessOpConversion<omp::OrderedOp>, 267 RegionLessOpConversion<omp::TargetEnterDataOp>, 268 RegionLessOpConversion<omp::TargetExitDataOp>, 269 RegionLessOpConversion<omp::TargetUpdateOp>, 270 RegionLessOpConversion<omp::YieldOp>, 271 RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>, 272 RegionLessOpWithVarOperandsConversion<omp::FlushOp>, 273 RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>, 274 RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>, 275 RegionOpConversion<omp::AtomicCaptureOp>, 276 RegionOpConversion<omp::CriticalOp>, 277 RegionOpConversion<omp::DistributeOp>, 278 RegionOpConversion<omp::LoopNestOp>, RegionOpConversion<omp::LoopOp>, 279 RegionOpConversion<omp::MaskedOp>, RegionOpConversion<omp::MasterOp>, 280 RegionOpConversion<omp::OrderedRegionOp>, 281 RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::SectionOp>, 282 RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SimdOp>, 283 RegionOpConversion<omp::SingleOp>, RegionOpConversion<omp::TargetDataOp>, 284 RegionOpConversion<omp::TargetOp>, RegionOpConversion<omp::TaskgroupOp>, 285 RegionOpConversion<omp::TaskloopOp>, RegionOpConversion<omp::TaskOp>, 286 RegionOpConversion<omp::TeamsOp>, RegionOpConversion<omp::WsloopOp>, 287 RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>>(converter); 288 } 289 290 namespace { 291 struct ConvertOpenMPToLLVMPass 292 : public impl::ConvertOpenMPToLLVMPassBase<ConvertOpenMPToLLVMPass> { 293 using Base::Base; 294 295 void runOnOperation() override; 296 }; 297 } // namespace 298 299 void ConvertOpenMPToLLVMPass::runOnOperation() { 300 auto module = getOperation(); 301 302 // Convert to OpenMP operations with LLVM IR dialect 303 RewritePatternSet patterns(&getContext()); 304 LLVMTypeConverter converter(&getContext()); 305 arith::populateArithToLLVMConversionPatterns(converter, patterns); 306 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 307 cf::populateAssertToLLVMConversionPattern(converter, patterns); 308 populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns); 309 populateFuncToLLVMConversionPatterns(converter, patterns); 310 populateOpenMPToLLVMConversionPatterns(converter, patterns); 311 312 LLVMConversionTarget target(getContext()); 313 target.addLegalOp<omp::BarrierOp, omp::FlushOp, omp::TaskwaitOp, 314 omp::TaskyieldOp, omp::TerminatorOp>(); 315 configureOpenMPToLLVMConversionLegality(target, converter); 316 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 317 signalPassFailure(); 318 } 319 320 //===----------------------------------------------------------------------===// 321 // ConvertToLLVMPatternInterface implementation 322 //===----------------------------------------------------------------------===// 323 namespace { 324 /// Implement the interface to convert OpenMP to LLVM. 325 struct OpenMPToLLVMDialectInterface : public ConvertToLLVMPatternInterface { 326 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; 327 void loadDependentDialects(MLIRContext *context) const final { 328 context->loadDialect<LLVM::LLVMDialect>(); 329 } 330 331 /// Hook for derived dialect interface to provide conversion patterns 332 /// and mark dialect legal for the conversion target. 333 void populateConvertToLLVMConversionPatterns( 334 ConversionTarget &target, LLVMTypeConverter &typeConverter, 335 RewritePatternSet &patterns) const final { 336 configureOpenMPToLLVMConversionLegality(target, typeConverter); 337 populateOpenMPToLLVMConversionPatterns(typeConverter, patterns); 338 } 339 }; 340 } // namespace 341 342 void mlir::registerConvertOpenMPToLLVMInterface(DialectRegistry ®istry) { 343 registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) { 344 dialect->addInterfaces<OpenMPToLLVMDialectInterface>(); 345 }); 346 } 347